diff --git a/.github/workflows/python-package-pip.yaml b/.github/workflows/inference.yaml similarity index 91% rename from .github/workflows/python-package-pip.yaml rename to .github/workflows/inference.yaml index 0b614278f..06998dc15 100644 --- a/.github/workflows/python-package-pip.yaml +++ b/.github/workflows/inference.yaml @@ -1,6 +1,6 @@ # Template: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: CoNeTTE test +name: CoNeTTE inference on: push: @@ -51,7 +51,11 @@ jobs: run: | python -m pip install -e .[dev] - # --- TESTS --- + # --- TESTS --- + - name: Check format with Black + run: | + python -m black --check --diff src + - name: Print install info run: | conette-info diff --git a/.github/workflows/training.yaml b/.github/workflows/training.yaml new file mode 100644 index 000000000..4833c5122 --- /dev/null +++ b/.github/workflows/training.yaml @@ -0,0 +1,98 @@ +# Template: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: CoNeTTE training + +on: + push: + branches: [ main, dev ] + pull_request: + branches: [ main, dev ] + +env: + CACHE_NUMBER: 0 # increase to reset cache manually + DATAROOT: "$HOME/.cache/data" + LOGROOT: "logs" + +# Cancel workflow if a new push occurs +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10"] + + defaults: + run: + shell: bash -el {0} + + steps: + # --- INSTALLATIONS --- + - name: Checkout repository and submodules + uses: actions/checkout@v2 + with: + submodules: recursive + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install soundfile + run: | + # For soundfile dep + sudo apt-get install libsndfile1 + + - name: Install local packages + run: | + python -m pip install -e .[train] + + - name: Print install info + run: | + conette-info + + - name: Prepare spaCy models + run: | + conette-prepare data=none default=false verbose=2 spacy=true + + - name: Load prepare cache + uses: actions/cache@v3 + id: cache_preparation + with: + path: | + ~/.cache/aac-metrics + ~/.cache/conette + ~/.cache/data/HDF + ~/.cache/huggingface + ~/.cache/torch + ~/nltk_data + key: ${{ runner.os }}-cache_preparation-${{ hashFiles('src/conette/prepare.py') }} + restore-keys: | + ${{ runner.os }}-cache_preparation + + - name: Prepare data and other models if necessary + if: ${{ steps.cache_preparation.outputs.cache-hit != 'true' }} + run: | + echo "Prepare data in dataroot '$DATAROOT'" + cnext_bl_path="$HOME/.cache/torch/hub/checkpoints/convnext_tiny_465mAP_BL_AC_70kit.pth" + conette-prepare data=clotho default=true pann=false pack_to_hdf=true data.clean_archives=true data.subsets=[val] audio_t.src_sr=44100 audio_t.pretrain_path=${cnext_bl_path} post_hdf_name=bl pretag=cnext_bl csum_in_hdf_name=false path.data=$DATAROOT verbose=2 + + # --- TESTS --- + - name: Train a model + run: | + target_hdf="clotho_val_resample_mean_convnext_ident_bl.hdf" + conette-train pl=conette expt=[clotho_cnext_bl,task_ds_src_camw] dm.train_hdfs=${target_hdf} dm.val_hdfs=${target_hdf} dm.test_hdfs=${target_hdf} dm.predict_hdfs=[] trainer.accelerator=cpu enable_dspeed=false path.data=$DATAROOT verbose=2 trainer=lim2 dm.bsize=3 trainer.max_epochs=1 path.log_root=$LOGROOT + + - name: Run CoNeTTE 
predict with trained model
+        run: |
+          latest_parent_logdir=`ls -Art "$LOGROOT" | grep train | tail -n 1`
+          latest_logdir=`ls -Art "$LOGROOT/$latest_parent_logdir" | tail -n 1`
+          model_path=$LOGROOT/$latest_parent_logdir/$latest_logdir
+          echo "Predict with $model_path..."
+          conette-predict --audio src/conette/data/sample.wav --model_path "$model_path"
diff --git a/.gitignore b/.gitignore
index 75b292b7d..403ffc85b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,5 +2,7 @@
 __pycache__/
 *.egg-info/
 Labbeti/conette/
-tmp/
+*tmp/
 dist/
+logs/
+data/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9249e8efd..ab01fb6a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,12 @@
 All notable changes to this project will be documented in this file.
 
+## [0.2.0] 2024-01-12
+### Added
+- CoNeTTE training source code, with entire data processing.
+- ConvNeXt-trans baseline training source code, with entire data processing.
+- ConvNeXt tag logits to CoNeTTE model outputs during inference.
+
 ## [0.1.4] 2023-11-20
 ### Fixed
 - Fix forbid repetition mode argument.
diff --git a/README.md b/README.md
index 98c0e5f87..aa5eb9415 100644
--- a/README.md
+++ b/README.md
@@ -9,14 +9,16 @@
 
-CoNeTTE is an audio captioning system, which generate a short textual description of the sound events in any audio file. The architecture and training are explained in the corresponding [paper](https://arxiv.org/pdf/2309.00454.pdf). The model has been developped by me ([Étienne Labbé](https://labbeti.github.io/)) during my PhD.
+CoNeTTE is an audio captioning system, which generates a short textual description of the sound events in any audio file. The architecture and training are explained in the corresponding [paper](https://arxiv.org/pdf/2309.00454.pdf). The model has been developed by me ([Étienne Labbé](https://labbeti.github.io/)) during my PhD. A simple interface to test CoNeTTE is available on the [HuggingFace website](https://huggingface.co/spaces/Labbeti/conette).
 
-## Installation
+## Inference
+
+### Installation
 ```bash
 python -m pip install conette
 ```
 
-## Usage with python
+### Usage with python
 
 ```py
 from conette import CoNeTTEConfig, CoNeTTEModel
@@ -57,26 +59,77 @@ candidate = outputs["cands"][0]
 print(candidate)
 ```
 
-## Usage with command line
+### Usage with command line
 Simply use the command `conette-predict` with `--audio PATH1 PATH2 ...` option. You can also export results to a CSV file using `--csv_export PATH`.
 
 ```bash
 conette-predict --audio "/your/path/to/audio.wav"
 ```
 
-## Performance
+### Performance
+The model has been trained on AudioCaps (AC), Clotho (CL), MACS (MA) and WavCaps (WC). The performance on the test subsets is:
 
 | Test data | SPIDEr (%) | SPIDEr-FL (%) | FENSE (%) | Vocab | Outputs | Scores |
 | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
 | AC-test | 44.14 | 43.98 | 60.81 | 309 | [Link](https://github.com/Labbeti/conette-audio-captioning/blob/main/results/conette/outputs_audiocaps_test.csv) | [Link](https://github.com/Labbeti/conette-audio-captioning/blob/main/results/conette/scores_audiocaps_test.yaml) |
 | CL-eval | 30.97 | 30.87 | 51.72 | 636 | [Link](https://github.com/Labbeti/conette-audio-captioning/blob/main/results/conette/outputs_clotho_eval.csv) | [Link](https://github.com/Labbeti/conette-audio-captioning/blob/main/results/conette/scores_clotho_eval.yaml) |
 
-This model checkpoint has been trained for the Clotho dataset, but it can also reach a good performance on AudioCaps with the "audiocaps" task.
+This model checkpoint has been trained with a focus on the Clotho dataset, but it can also reach good performance on AudioCaps with the "audiocaps" task.
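+
+The per-dataset scores above are obtained by conditioning the model on a dataset-specific task embedding. As a minimal sketch, reusing the `model` instance from the usage example above (the exact `task` keyword values assumed here are the dataset names):
+
+```py
+# Hypothetical illustration: ask the same checkpoint for an AudioCaps-style caption.
+outputs = model("/your/path/to/audio.wav", task="audiocaps")
+print(outputs["cands"][0])
+```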
 
-## Limitations
+### Limitations
 - The model expected audio sampled at **32 kHz**. The model automatically resample up or down the input audio files. However, it might give worse results, especially when using audio with lower sampling rates.
 - The model has been trained on audio lasting from **1 to 30 seconds**. It can handle longer audio files, but it might require more memory and give worse results.
 
+## Train a model
+### Requirements
+- Intended for Ubuntu 20.04 only. Requires **java** < 1.13, **ffmpeg**, **yt-dlp**, and **zip** commands.
+- Minimum recommended GPU: V100-32G.
+- The WavCaps dataset might require more than 2 TB of disk storage.
+
+### Installation
+By default, **only the inference requirements are installed for conette**. To install the training requirements, use the following command:
+```bash
+python -m pip install conette[train]
+```
+If you have already installed conette for inference, it is **highly recommended to create another environment** before installing conette for training.
+
+### Download external models and data
+These steps might take a while (a few hours to download and prepare everything, depending on your CPU, GPU and SSD/HDD).
+
+First, download the ConvNeXt, NLTK and spaCy models:
+```bash
+conette-prepare data=none default=true pack_to_hdf=false csum_in_hdf_name=false pann=false
+```
+
+Then download the 4 datasets used to train CoNeTTE:
+```bash
+cnext_bl_path="$HOME/.cache/torch/hub/checkpoints/convnext_tiny_465mAP_BL_AC.pth"
+common_args="data.download=true pack_to_hdf=true audio_t=resample_mean_convnext audio_t.pretrain_path=${cnext_bl_path} post_hdf_name=bl pretag=cnext_bl"
+
+conette-prepare data=audiocaps audio_t.src_sr=32000 ${common_args}
+conette-prepare data=clotho audio_t.src_sr=44100 ${common_args}
+conette-prepare data=macs audio_t.src_sr=48000 ${common_args}
+conette-prepare data=wavcaps audio_t.src_sr=32000 ${common_args} datafilter.min_audio_size=0.1 datafilter.max_audio_size=30.0 datafilter.sr=32000
+```
+
+### Train a model
+CNext-trans (baseline) on CL only (~3 hours on 1 GPU V100-32G)
+```bash
+conette-train expt=[clotho_cnext_bl] pl=baseline
+```
+
+CoNeTTE on AC+CL+MA+WC, specialized for CL (~4 hours on 1 GPU V100-32G)
+```bash
+conette-train expt=[camw_cnext_bl_for_c,task_ds_src_camw] pl=conette
+```
+
+CoNeTTE on AC+CL+MA+WC, specialized for AC (~3 hours on 1 GPU V100-32G)
+```bash
+conette-train expt=[camw_cnext_bl_for_a,task_ds_src_camw] pl=conette
+```
+
+**About reproducibility**: any training that uses AC data cannot be exactly reproduced, because part of this data has been deleted from its YouTube source and I cannot share my own audio files.
+
 ## Citation
 The preprint version of the paper describing CoNeTTE is available on arxiv: https://arxiv.org/pdf/2309.00454.pdf
@@ -96,7 +149,7 @@ The preprint version of the paper describing CoNeTTE is available on arxiv: http
 
 ## Additional information
 - CoNeTTE stands for **Co**nv**Ne**Xt-**T**ransformer with **T**ask **E**mbedding.
 - Model weights are available on HuggingFace: https://huggingface.co/Labbeti/conette
-- The encoder part of the architecture is based on a ConvNeXt model for audio classification, available here: https://zenodo.org/record/8020843 under the filename "convnext_tiny_465mAP_BL_AC_70kit.pth".
+- The weights of the encoder part of the architecture is based on a ConvNeXt model for audio classification, available here: https://zenodo.org/record/8020843 under the filename "convnext_tiny_465mAP_BL_AC_70kit.pth". ## Contact Maintainer: diff --git a/environment-train.yaml b/environment-train.yaml new file mode 100644 index 000000000..d540b0bc5 --- /dev/null +++ b/environment-train.yaml @@ -0,0 +1,280 @@ +name: env_conette2 +channels: +- pkgs/main +dependencies: +- _libgcc_mutex=0.1 +- _openmp_mutex=5.1 +- bzip2=1.0.8 +- ca-certificates=2023.08.22 +- ld_impl_linux-64=2.38 +- libffi=3.4.4 +- libgcc-ng=11.2.0 +- libgomp=11.2.0 +- libstdcxx-ng=11.2.0 +- libuuid=1.41.5 +- ncurses=6.4 +- openssl=3.0.12 +- pip=23.3.1 +- python=3.10.13 +- readline=8.2 +- setuptools=68.0.0 +- sqlite=3.41.2 +- tk=8.6.12 +- tzdata=2023c +- wheel=0.41.2 +- xz=5.4.5 +- zlib=1.2.13 +- pip: + - aac-datasets==0.4.1 + - aac-metrics==0.5.0 + - absl-py==2.0.0 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - alembic==1.13.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.1.0 + - argon2-cffi==23.1.0 + - argon2-cffi-bindings==21.2.0 + - arrow==1.3.0 + - asttokens==2.4.1 + - async-lru==2.0.4 + - async-timeout==4.0.3 + - attrs==23.1.0 + - audiomentations==0.34.1 + - audioread==3.0.1 + - autopage==0.5.2 + - Babel==2.13.1 + - beautifulsoup4==4.12.2 + - bert-score==0.3.13 + - black==23.12.0 + - bleach==6.1.0 + - blis==0.7.11 + - bokeh==3.3.2 + - Brotli==1.1.0 + - cachetools==5.3.2 + - catalogue==2.0.10 + - certifi==2023.11.17 + - cffi==1.16.0 + - charset-normalizer==3.3.2 + - click==8.1.7 + - cliff==4.4.0 + - cloudpathlib==0.16.0 + - cmaes==0.10.0 + - cmd2==2.4.3 + - colorlog==6.8.0 + - comm==0.2.0 + - confection==0.1.4 + - contourpy==1.2.0 + - cycler==0.12.1 + - cymem==2.0.8 + - daal==2024.0.1 + - daal4py==2024.0.1 + - debugpy==1.8.0 + - decorator==5.1.1 + - deepspeed==0.9.5 + - defusedxml==0.7.1 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fastjsonschema==2.19.0 + - filelock==3.13.1 + - flake8==6.1.0 + - fonttools==4.46.0 + - fqdn==1.5.1 + - frozenlist==1.4.0 + - fsspec==2023.12.2 + - gensim==4.3.2 + - google-auth==2.25.2 + - google-auth-oauthlib==1.1.0 + - greenlet==3.0.2 + - grpcio==1.60.0 + - h5py==3.10.0 + - hjson==3.1.0 + - huggingface-hub==0.19.4 + - hydra-colorlog==1.2.0 + - hydra-core==1.3.2 + - hydra-optuna-sweeper==1.2.0 + - idna==3.6 + - imageio==2.33.1 + - importlib-metadata==7.0.0 + - inflate64==1.0.0 + - iniconfig==2.0.0 + - intel-extension-for-pytorch==2.1.0 + - ipykernel==6.27.1 + - ipython==8.18.1 + - ipywidgets==8.1.1 + - isoduration==20.11.0 + - jedi==0.19.1 + - Jinja2==3.1.2 + - joblib==1.3.2 + - json5==0.9.14 + - jsonpointer==2.4 + - jsonschema==4.20.0 + - jsonschema-specifications==2023.11.2 + - julius==0.2.7 + - jupyter==1.0.0 + - jupyter-console==6.6.3 + - jupyter-events==0.9.0 + - jupyter-lsp==2.2.1 + - jupyter_client==8.6.0 + - jupyter_core==5.5.0 + - jupyter_server==2.12.1 + - jupyter_server_terminals==0.5.0 + - jupyterlab==4.0.9 + - jupyterlab-widgets==3.0.9 + - jupyterlab_pygments==0.3.0 + - jupyterlab_server==2.25.2 + - kiwisolver==1.4.5 + - langcodes==3.3.0 + - language-tool-python==2.7.1 + - lazy_loader==0.3 + - librosa==0.10.1 + - lightning-utilities==0.10.0 + - llvmlite==0.41.1 + - Mako==1.3.0 + - Markdown==3.5.1 + - MarkupSafe==2.1.3 + - matplotlib==3.8.2 + - matplotlib-inline==0.1.6 + - mccabe==0.7.0 + - mistune==3.0.2 + - msgpack==1.0.7 + - multidict==6.0.4 + - multivolumefile==0.2.3 + - murmurhash==1.0.10 + - mypy-extensions==1.0.0 + - nbclient==0.9.0 + - nbconvert==7.12.0 + - 
nbformat==5.9.2 + - nest-asyncio==1.5.8 + - networkx==3.2.1 + - ninja==1.11.1.1 + - nltk==3.8.1 + - nnAudio==0.3.2 + - notebook==7.0.6 + - notebook_shim==0.2.3 + - numba==0.58.1 + - numpy==1.26.2 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - optuna==2.10.1 + - overrides==7.4.0 + - packaging==23.2 + - pandas==2.1.4 + - pandocfilters==1.5.0 + - parso==0.8.3 + - pathspec==0.12.1 + - pbr==6.0.0 + - pexpect==4.9.0 + - Pillow==10.1.0 + - platformdirs==4.1.0 + - pluggy==1.3.0 + - pooch==1.8.0 + - preshed==3.0.9 + - prettytable==3.9.0 + - prometheus-client==0.19.0 + - prompt-toolkit==3.0.41 + - protobuf==4.23.4 + - psutil==5.9.6 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - py-cpuinfo==9.0.0 + - py7zr==0.20.8 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pybcj==1.0.2 + - pycodestyle==2.11.1 + - pycparser==2.21 + - pycryptodomex==3.19.0 + - pydantic==1.10.13 + - pyemd==1.0.0 + - pyflakes==3.1.0 + - Pygments==2.17.2 + - pyparsing==3.1.1 + - pyperclip==1.8.2 + - pyppmd==1.1.0 + - pytest==7.4.3 + - python-dateutil==2.8.2 + - python-json-logger==2.0.7 + - pytorch-lightning==1.9.5 + - pytorch-ranger==0.1.1 + - pytz==2023.3.post1 + - PyYAML==6.0.1 + - pyzmq==25.1.2 + - pyzstd==0.15.9 + - qtconsole==5.5.1 + - QtPy==2.4.1 + - referencing==0.32.0 + - regex==2023.10.3 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - resampy==0.2.2 + - rfc3339-validator==0.1.4 + - rfc3986-validator==0.1.1 + - rpds-py==0.13.2 + - rsa==4.9 + - safetensors==0.4.1 + - scikit-image==0.22.0 + - scikit-learn==1.3.2 + - scikit-learn-intelex==2024.0.1 + - scipy==1.11.4 + - Send2Trash==1.8.2 + - sentence-transformers==2.2.2 + - sentencepiece==0.1.99 + - six==1.16.0 + - smart-open==6.4.0 + - sniffio==1.3.0 + - soundfile==0.12.1 + - soupsieve==2.5 + - soxr==0.3.7 + - spacy==3.7.2 + - spacy-legacy==3.0.12 + - spacy-loggers==1.0.5 + - SQLAlchemy==2.0.23 + - srsly==2.4.8 + - stack-data==0.6.3 + - stevedore==5.1.0 + - tbb==2021.11.0 + - tensorboard==2.15.1 + - tensorboard-data-server==0.7.2 + - terminado==0.18.0 + - texttable==1.7.0 + - thinc==8.2.1 + - threadpoolctl==3.2.0 + - tifffile==2023.12.9 + - timm==0.9.12 + - tinycss2==1.2.1 + - tokenizers==0.13.3 + - tomli==2.0.1 + - torch==1.13.1 + - torch-optimizer==0.3.0 + - torchaudio==0.13.1 + - torchlibrosa==0.1.0 + - torchmetrics==1.2.1 + - torchopenl3==1.0.1 + - torchtext==0.14.1 + - torchvision==0.14.1 + - tornado==6.4 + - tqdm==4.66.1 + - traitlets==5.14.0 + - transformers==4.30.2 + - typer==0.9.0 + - types-python-dateutil==2.8.19.14 + - typing_extensions==4.9.0 + - tzdata==2023.3 + - uri-template==1.3.0 + - urllib3==2.1.0 + - wasabi==1.1.2 + - wcwidth==0.2.12 + - weasel==0.3.4 + - webcolors==1.13 + - webencodings==0.5.1 + - websocket-client==1.7.0 + - Werkzeug==3.0.1 + - widgetsnbextension==4.0.9 + - xyzservices==2023.10.1 + - yarl==1.9.4 + - zipp==3.17.0 diff --git a/pyproject.toml b/pyproject.toml index 1f927737a..86e4ce000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,21 +20,7 @@ classifiers = [ maintainers = [ {name = "Etienne Labbé (Labbeti)", email = "labbeti.pub@gmail.com"}, ] -dynamic = ["version"] - -dependencies = [ - "setuptools", - "pyyaml", - "torch>=1.10", - "torchaudio", - "torchlibrosa", - "pytorch-lightning>=1.7,<2.0", - "nltk", - "spacy", - "tensorboard", - "transformers>=4.35.0", - "omegaconf", -] +dynamic = ["version", "dependencies", "optional-dependencies"] [project.urls] Repository = 
"https://github.com/Labbeti/conette-audio-captioning.git" @@ -43,14 +29,8 @@ Changelog = "https://github.com/Labbeti/conette-audio-captioning/blob/main/CHANG [project.scripts] conette-info = "conette.info:print_install_info" conette-predict = "conette.predict:main_predict" - -[project.optional-dependencies] -dev = [ - "pytest", - "flake8", - "black", - "ipykernel", -] +conette-train = "conette.train:main_train" +conette-prepare = "conette.prepare:main_prepare" [tool.setuptools.packages.find] where = ["src"] # list of folders that contain the packages (["."] by default) @@ -58,3 +38,8 @@ include = ["conette*"] # package names should match these glob patterns (["*"] [tool.setuptools.dynamic] version = {attr = "conette.__version__"} +dependencies = {file = ["requirements.txt"]} +optional-dependencies = { dev = { file = ["requirements-dev.txt"] }, train = { file = ["requirements-train.txt"]}} + +[tool.ruff] +ignore = ["E501", "E402"] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..5734791b6 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +pytest +flake8 +black +ipykernel diff --git a/requirements-train.txt b/requirements-train.txt new file mode 100644 index 000000000..45276ac4d --- /dev/null +++ b/requirements-train.txt @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +torchvision +torchtext +torchmetrics +pip +hydra-core +h5py +sentencepiece +py7zr +jupyter +flake8 +pytest +bokeh +sentence-transformers +numpy +scikit-image +gensim +click +pytz +pyemd +pydantic>=1.8.2 +matplotlib +scikit-learn-intelex +audiomentations +bert-score +black +torch-optimizer +torchopenl3 +hydra-optuna-sweeper +hydra-colorlog +timm +audiomentations +language_tool_python +hydra_colorlog +aac-datasets~=0.4.0 +aac-metrics~=0.5.3 +deepspeed==0.9.5 +intel_extension_for_pytorch diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..aaee48de7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +setuptools +pyyaml +torch>=1.10,<2.0 +torchaudio +torchlibrosa +pytorch-lightning>=1.7,<2.0 +nltk +spacy +tensorboard +transformers +omegaconf diff --git a/src/conette/__init__.py b/src/conette/__init__.py index 72b4da9bf..db3195f3c 100644 --- a/src/conette/__init__.py +++ b/src/conette/__init__.py @@ -10,7 +10,7 @@ __license__ = "MIT" __maintainer__ = "Etienne Labbé (Labbeti)" __status__ = "Development" -__version__ = "0.1.4" +__version__ = "0.2.0" from pathlib import Path diff --git a/src/conette/callbacks/__init__.py b/src/conette/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/conette/callbacks/aac_evaluator.py b/src/conette/callbacks/aac_evaluator.py new file mode 100644 index 000000000..4d11d022a --- /dev/null +++ b/src/conette/callbacks/aac_evaluator.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import logging +import os +import os.path as osp +import tempfile +import time + +from typing import Any, Optional, Union + +import torch +import yaml + +from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor +from torch.utils.data.dataloader import DataLoader + +from aac_metrics.utils.checks import is_mono_sents, is_mult_sents +from aac_metrics.utils.collections import flat_list, unflat_list + +from conette.metrics.classes.all_metrics import AllMetrics +from conette.nn.functional.misc import 
move_to_rec +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.utils.collections import all_eq +from conette.utils.custom_logger import CustomTensorboardLogger +from conette.utils.dcase import export_to_dcase_task6a_csv +from conette.utils.log_utils import warn_once + + +pylog = logging.getLogger(__name__) + + +class AACEvaluator(Callback): + """Callback which stores candidates and references during testing to produce AAC scores. + + Include metrics : BLEU1, BLEU2, BLEU3, BLEU4, METEOR, ROUGE-L, CIDEr, SPICE, SPIDEr. + """ + + CANDS_PREFIX = "cands" + MREFS_KEY = "mrefs" + + def __init__( + self, + subrun_path: Optional[str], + test_tokenizer: AACTokenizer, + cache_path: str = "~/.cache", + java_path: str = "java", + tmp_path: str = tempfile.gettempdir(), + ckpt_name: str = "unk", + verbose: int = 1, + debug: bool = False, + save_to_csv: bool = True, + save_dcase_csv_file: bool = False, + metric_device: Union[str, torch.device, None] = None, + cpus: Optional[int] = None, + ) -> None: + if subrun_path is not None: + subrun_path = osp.expandvars(subrun_path) + + super().__init__() + self._subrun_dir = subrun_path + self._test_tokenizer = test_tokenizer + self._cache_path = cache_path + self._java_path = java_path + self._tmp_path = tmp_path + self._model_name = ckpt_name + self._verbose = verbose + self._debug = debug + self._save_to_csv = save_to_csv + self._save_dcase_csv_file = save_dcase_csv_file + self._metric_device = metric_device + self._cpus = cpus + + self._all_outputs: dict[int, dict[str, Any]] = {} + + # Note : we avoid compute scores for + # - AudioCaps/train because it is too large + # - Clotho/test because it does not have any references + # - Clotho/anasysis because it does not have any references + self._excluded_datasubsets_metrics = ( + "audiocaps_train", + "clotho_test", + "clotho_analysis", + ) + self._all_metrics = None + + # Callback methods + def on_predict_epoch_start(self, trainer, pl_module) -> None: + self._all_outputs = {} + if self._verbose >= 1: + pylog.debug(f"Starting PREDICT epoch with model_name='{self._model_name}'") + + def on_predict_batch_end( + self, + trainer, + pl_module, + outputs: dict[str, Any], + batch: dict[str, Any], + batch_idx: int, + dataloader_idx: int, + ) -> None: + return self.on_test_batch_end( + trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) + + def on_predict_epoch_end( + self, + trainer, + pl_module: LightningModule, + outputs, + ) -> None: + if not self._save_to_csv: + return None + + for outputs_ in self._all_outputs.values(): + datasubset = _get_datasubset_name(outputs_) + + if self._subrun_dir is not None and osp.isdir(self._subrun_dir): + self._save_outputs_to_csv( + self._subrun_dir, + datasubset, + outputs_, + {}, + ) + else: + pylog.error( + f"Cannot save outputs to CSV because logdir is not a valid directory. 
(logdir={self._subrun_dir}, {datasubset=})" + ) + + def on_test_start(self, trainer, pl_module: LightningModule) -> None: + if self._all_metrics is None: + if self._metric_device is not None: + device = self._metric_device + else: + device = pl_module.device + + if hasattr(pl_module, "tokenizer") and isinstance( + pl_module.tokenizer, AACTokenizer + ): + train_vocab = pl_module.tokenizer.get_vocab() + else: + train_vocab = None + + self._all_metrics = AllMetrics( + preprocess=False, + device=device, + cache_path=self._cache_path, + java_path=self._java_path, + tmp_path=self._tmp_path, + meteor_java_max_memory="2G", + spice_java_max_memory="8G", + spice_n_threads=self._cpus, + spice_timeout=[3600], + train_vocab=train_vocab, + verbose=self._verbose, + ) + + if self._verbose >= 2: + pylog.debug(f"{len(self._all_metrics)} metrics has been initialized.") + + datamodule = trainer.datamodule # type: ignore + if datamodule is not None: + test_loaders = datamodule.test_dataloader() + else: + test_loaders = [] + + assert isinstance(test_loaders, list) and all( + isinstance(loader, DataLoader) for loader in test_loaders + ) + sizes = tuple(map(len, test_loaders)) + pylog.debug(f"Test loader sizes: {sizes}") + + def on_test_epoch_start(self, trainer, pl_module) -> None: + self._all_outputs = {} + if self._verbose >= 1: + pylog.debug(f"Starting TEST epoch with model_name='{self._model_name}'") + + def on_test_batch_end( + self, + trainer, + pl_module, + outputs: dict[str, Any], + batch: dict[str, Any], + batch_idx: int, + dataloader_idx: int, + ) -> None: + if outputs is None: + warn_once("Lightning module has returned None during test epoch.", pylog) + return None + + outputs = move_to_rec(outputs, device=torch.device("cpu")) + + if dataloader_idx not in self._all_outputs.keys(): + self._all_outputs[dataloader_idx] = {} + + for key, batch_values in outputs.items(): + if key not in self._all_outputs[dataloader_idx].keys(): + self._all_outputs[dataloader_idx][key] = [] + self._all_outputs[dataloader_idx][key] += list(batch_values) + + for key in ("fname", "index", "dataset", "subset"): + if key not in batch.keys(): + raise ValueError(f"Cannot find {key=} in batch.") + if key not in self._all_outputs[dataloader_idx].keys(): + self._all_outputs[dataloader_idx][key] = [] + self._all_outputs[dataloader_idx][key] += batch[key] + + def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: + datasubsets = [] + for outputs in self._all_outputs.values(): + # Sanity check + n_items = len(next(iter(outputs.values()))) + invalid_sizes_keys = [ + key for key, values in outputs.items() if len(values) != n_items + ] + if len(invalid_sizes_keys) > 0: + sizes = [len(outputs[key]) for key in invalid_sizes_keys] + raise RuntimeError( + f"Invalid number of values for keys={invalid_sizes_keys} (expected {n_items} but found {sizes=})." + ) + + datasubset = _get_datasubset_name(outputs) + counter = datasubsets.count(datasubset) + if counter > 0: + old_datasubset = datasubset + datasubset = f"{datasubset}_{counter+1}" + pylog.error( + f"Found duplicated subset '{old_datasubset}'. Renaming to '{datasubset}'." + ) + assert datasubset not in datasubsets + datasubsets.append(datasubset) + + # Tokenize candidates and references + sents_keys = [ + key + for key in outputs.keys() + if key.startswith(self.CANDS_PREFIX) or key == self.MREFS_KEY + ] + + if self._verbose >= 2: + pylog.debug( + f"Process sentences with tokenizer... 
({tuple(sents_keys)=}" + ) + + for key in sents_keys: + raw_sents = outputs[key] + + if is_mono_sents(raw_sents): + sents = self._test_tokenizer.tokenize_batch(raw_sents) + sents = self._test_tokenizer.detokenize_batch(sents) + + elif is_mult_sents(raw_sents): + flat_raw_sents, sizes = flat_list(raw_sents) + flat_sents = self._test_tokenizer.tokenize_batch(flat_raw_sents) + flat_sents = self._test_tokenizer.detokenize_batch(flat_sents) + sents = unflat_list(flat_sents, sizes) + + else: + raise TypeError(f"Cannot detect sentences type. (with {key=})") + + outputs[key] = sents + + if self._verbose >= 2: + pylog.debug(f"Sentences processed. ({tuple(sents_keys)=})") + + sents_scores = {} + if datasubset not in self._excluded_datasubsets_metrics: + with torch.inference_mode(): + corpus_scores, sents_scores = self._compute_metrics( + outputs, datasubset + ) + + if self._verbose >= 1: + pylog.info( + f"Global scores for dataset {datasubset}:\n{yaml.dump(corpus_scores, sort_keys=False)}" + ) + self._log_global_scores(corpus_scores, datasubset, pl_module) + if self._verbose >= 1: + self._print_example(outputs, datasubset, pl_module, sents_scores) + else: + pylog.debug(f"Skipping metrics for subset '{datasubset}'...") + + if self._save_to_csv: + if self._subrun_dir is not None and osp.isdir(self._subrun_dir): + self._save_outputs_to_csv( + self._subrun_dir, + datasubset, + outputs, + sents_scores, + ) + else: + pylog.error( + f"Cannot save outputs to CSV because logdir is not a valid directory. (logdir={self._subrun_dir}, {datasubset=})" + ) + + self._all_outputs = {} + + # AACEvaluator methods + def set_model_name(self, model_name: str) -> None: + self._model_name = model_name + + def _compute_metrics( + self, + outputs: dict[str, list], + datasubset: str, + ) -> tuple[dict[str, float], dict[str, list[float]]]: + corpus_scores = {} + sents_scores = {} + + if self._all_metrics is None: + return corpus_scores, sents_scores + + start_time = time.perf_counter() + + pred_keys = [ + key + for key, values in outputs.items() + if key.startswith(self.CANDS_PREFIX) + and isinstance(values, list) + and all(isinstance(value, str) for value in values) + ] + all_mrefs = outputs[self.MREFS_KEY] + + if self._verbose >= 1: + n_metrics = len(self._all_metrics) + pylog.info( + f"Start computing metrics... ({datasubset=}, n_preds={len(all_mrefs)}, n_preds_types={len(pred_keys)}, {n_metrics=})" + ) + + for pred_key in pred_keys: + all_cands = outputs[pred_key] + + if self._verbose >= 1: + pylog.debug( + f"Computing sentence level metrics... ({datasubset=}, {pred_key=})" + ) + + pred_global_scores, pred_sents_scores = self._all_metrics( + all_cands, + all_mrefs, + ) + corpus_scores |= { + f"{self._model_name}.{pred_key}.{metric_name}": score + for metric_name, score in pred_global_scores.items() + } + sents_scores |= { + f"{self._model_name}.{pred_key}.{metric_name}": scores + for metric_name, scores in pred_sents_scores.items() + } + + if self._verbose >= 1: + end_time = time.perf_counter() + duration_s = end_time - start_time + pylog.info( + f"Computing metrics finished in {duration_s:.2f}s. 
({datasubset=})" + ) + + # Sanity check + if __debug__: + invalid_corpus_scores = tuple( + [ + name + for name, scores in corpus_scores.items() + if not isinstance(scores, Tensor) or scores.ndim != 0 + ] + ) + invalid_sents_scores = tuple( + [ + name + for name, scores in sents_scores.items() + if not isinstance(scores, Tensor) + or scores.ndim != 1 + or scores.shape[0] != len(all_mrefs) + ] + ) + if len(invalid_corpus_scores) > 0: + raise ValueError( + f"Invalid global scores. (found {invalid_corpus_scores=})" + ) + + if len(invalid_sents_scores) > 0: + raise ValueError( + f"Invalid local scores. (found {invalid_sents_scores=})" + ) + + corpus_scores = {name: score.item() for name, score in corpus_scores.items()} + sents_scores = {name: scores.tolist() for name, scores in sents_scores.items()} + + return corpus_scores, sents_scores + + def _log_global_scores( + self, + corpus_scores: dict[str, float], + datasubset: str, + pl_module: LightningModule, + ) -> None: + global_scores_with_datasubset = { + f"{datasubset}/{key}": score for key, score in corpus_scores.items() + } + for pl_logger in pl_module.loggers: + if isinstance(pl_logger, CustomTensorboardLogger): + pl_logger.log_hyperparams( + params={}, metrics=global_scores_with_datasubset + ) + pl_logger.update_files() + + def _print_example( + self, + outputs: dict[str, list], + datasubset: str, + pl_module: LightningModule, + sents_scores: dict[str, list[float]], + ) -> None: + assert self._test_tokenizer is not None + n_outputs = len(outputs["fname"]) + indexes = torch.randint(0, n_outputs, (1,)).tolist() + + pylog.info( + f"Show {len(indexes)} example(s) with model_name={self._model_name} : " + ) + + for idx in indexes: + fname = outputs["fname"][idx] + dset_index = outputs["index"][idx] + candidates = { + key: candidates_sents[idx] + for key, candidates_sents in outputs.items() + if key.startswith(self.CANDS_PREFIX) + } + mult_references = outputs[self.MREFS_KEY][idx] + + lines = "-" * 10 + width = 128 + + local_main_metrics = { + k: v[idx] for k, v in sents_scores.items() if "spider" in k + } + + infos = { + "datasubset": datasubset, + "index": dset_index, + "fname": fname, + } | local_main_metrics + + pylog.info( + f"\n" + f"{lines}\nInfos\n{lines}\n{yaml.dump(infos, width=width, sort_keys=False)}" + f"{lines}\nCandidates\n{lines}\n{yaml.dump(candidates, width=width, sort_keys=False)}" + f"{lines}\nReferences\n{lines}\n{yaml.dump(mult_references, width=width, sort_keys=False)}" + ) + + # Log examples + loggers = pl_module.loggers + for logger in loggers: + if isinstance(logger, TensorBoardLogger): + prefix = logger.name + logger.experiment.add_text( + f"{prefix}/{datasubset}_cands_{dset_index}", + yaml.dump(candidates, sort_keys=False), + ) + logger.experiment.add_text( + f"{prefix}/{datasubset}_mrefs_{dset_index}", + yaml.dump(mult_references, sort_keys=False), + ) + + def _save_outputs_to_csv( + self, + dpath: str, + datasubset: str, + outs: dict[str, list], + sents_scores: dict[str, list[float]], + ) -> None: + # Sanity check + lens = list(map(len, outs.values())) + list(map(len, sents_scores.values())) + assert all_eq(lens), f"{lens=}" + + n_items = lens[0] + csv_fname = f"{self._model_name}_outputs_{datasubset}.csv" + csv_fpath = osp.join(dpath, csv_fname) + + def process(key: str, value: Any) -> Any: + if isinstance(value, Tensor): + return value.tolist() + else: + return value + + csv_all_values = outs | sents_scores + + with open(csv_fpath, "w") as file: + keys = list(csv_all_values.keys()) + writer = csv.DictWriter(file, 
fieldnames=keys) + writer.writeheader() + + for i in range(n_items): + row = {key: values[i] for key, values in csv_all_values.items()} + row = {key: process(key, value) for key, value in row.items()} + writer.writerow(row) + + if self._save_dcase_csv_file: + fnames = outs["fname"] + mcands = {k: v for k, v in outs.items() if k.startswith(self.CANDS_PREFIX)} + + dcase_dpath = osp.join(dpath, "dcase") + os.makedirs(dcase_dpath, exist_ok=True) + + for cands_name, cands in mcands.items(): + if len(mcands) == 1: + dcase_fname = ( + f"submission_output_{self._model_name}_{datasubset}.csv" + ) + else: + dcase_fname = f"submission_output_{self._model_name}_{datasubset}_{cands_name}.csv" + + dcase_fpath = osp.join(dcase_dpath, dcase_fname) + export_to_dcase_task6a_csv(dcase_fpath, fnames, cands) + + +def _get_datasubset_name(outputs: dict[str, Any]) -> str: + datanames = list(sorted(set(map(str.lower, outputs["dataset"])))) + subsets = list(sorted(set(map(str.lower, outputs["subset"])))) + if len(datanames) == 1 and len(subsets) == 1: + datasubset = f"{datanames[0]}_{subsets[0]}" + else: + datasubset = f"mix_{'_'.join(datanames)}_{'_'.join(subsets)}" + return datasubset diff --git a/src/conette/callbacks/aac_validator.py b/src/conette/callbacks/aac_validator.py new file mode 100644 index 000000000..dac26308f --- /dev/null +++ b/src/conette/callbacks/aac_validator.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Iterable, Optional, Union + +import torch + +from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks.callback import Callback +from torch import nn + +from aac_metrics.classes.cider_d import CIDErD +from aac_metrics.classes.fense import FENSE + +from conette.metrics.classes.diversity import Diversity +from conette.metrics.classes.text_stats import TextStats +from conette.nn.functional.get import get_device + + +class AACValidator(Callback): + def __init__( + self, + monitors: Union[str, Iterable[str]], + metrics_keys: Union[str, Iterable[str]] = (), + computation_device: Union[str, torch.device, None] = "auto", + other_device: Union[str, torch.device, None] = "cpu", + build_on_start: bool = False, + ) -> None: + if isinstance(metrics_keys, str): + metrics_keys = [metrics_keys] + else: + metrics_keys = list(metrics_keys) + + computation_device = get_device(computation_device) + other_device = get_device(other_device) + + if isinstance(monitors, str): + monitors = [monitors] + else: + monitors = list(monitors) + + super().__init__() + self._monitors = monitors + self._metrics_keys = metrics_keys + self._computation_device = computation_device + self._other_device = other_device + + self._cands_dic: dict[str, list[str]] = {} + self._mrefs_lst: list[list[str]] = [] + self._metrics = {} + + if build_on_start: + self.__build_metrics(computation_device) + + # Callback methods + def on_fit_start(self, trainer, pl_module) -> None: + if len(self._metrics) == 0: + self.__build_metrics(pl_module.device) + + def on_fit_end(self, trainer, pl_module) -> None: + del self._metrics + self._metrics = {} + + def on_train_batch_end( + self, + trainer, + pl_module, + outputs: Optional[dict[str, Any]], + batch, + batch_idx, + unused=0, + ) -> None: + self.__on_batch_end(outputs) + + def on_validation_batch_end( + self, + trainer, + pl_module, + outputs: Optional[dict[str, Any]], + batch, + batch_idx, + dataloader_idx, + ) -> None: + self.__on_batch_end(outputs) + + def on_train_epoch_end(self, trainer, pl_module: LightningModule) -> None: + 
self.__on_epoch_end(pl_module, "train/") + self._cands_dic = {} + self._mrefs_lst = [] + + def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None: + self.__on_epoch_end(pl_module, "val/") + self._cands_dic = {} + self._mrefs_lst = [] + + # Other methods + def __build_metrics( + self, + computation_device: Union[str, torch.device, None], + ) -> None: + metrics: dict[str, nn.Module] = { + "cider_d": CIDErD(return_all_scores=True), + "div1": Diversity(return_all_scores=True), + "stats": TextStats(return_all_scores=True), + } + + if computation_device is None: + computation_device = self._computation_device + else: + self._computation_device = get_device(computation_device) + + if ( + any("fense" in monitor for monitor in self._monitors) + or "fense" in self._metrics_keys + ): + fense = FENSE( + return_all_scores=True, + device=computation_device, + ) + metrics["fense"] = fense + + self._metrics = metrics + self._metrics = { + name: metric.to(device=self._other_device) + for name, metric in self._metrics.items() + } + + def __on_batch_end(self, outputs: Optional[dict[str, Any]]) -> None: + if outputs is None or not isinstance(outputs, dict): + return None + + cands_dic: dict[str, list[str]] = { + name: values for name, values in outputs.items() if name.startswith("cands") + } + refs: Optional[list[str]] = outputs.get("refs") + mrefs: Optional[list[list[str]]] = outputs.get("mrefs") + + if len(cands_dic) == 0: + return None + if (refs is None) == (mrefs is None): + raise RuntimeError( + f"Invalid batch output with ({refs is None=}, {mrefs is None=}). (expected (None, [...]) or ([...], None))" + ) + + if mrefs is None: + if refs is None: + raise RuntimeError( + f"Found candidates but no references. ({cands_dic=})" + ) + mrefs = [[ref] for ref in refs] + + for key, cands_lst in cands_dic.items(): + if key in self._cands_dic: + self._cands_dic[key] += cands_lst + else: + self._cands_dic[key] = cands_lst + + self._mrefs_lst += mrefs + + def __on_epoch_end( + self, + pl_module: LightningModule, + prefix: str, + ) -> None: + if any( + len(cands_lst) != len(self._mrefs_lst) + for cands_lst in self._cands_dic.values() + ): + cands_lens = list(map(len, self._cands_dic.values())) + mrefs_lens = [len(self._mrefs_lst)] * len(cands_lens) + raise ValueError( + f"Invalid number of candidates and references. 
(found {cands_lens=} != {mrefs_lens=})" + ) + + if len(self._cands_dic) <= 0: + return None + + self._metrics = { + name: metric.to(device=self._computation_device) + for name, metric in self._metrics.items() + } + + if not hasattr(pl_module, "tokenizer"): + raise RuntimeError("Cannot find tokenizer in pl_module.") + tokenizer: Any = pl_module.tokenizer # type: ignore + mrefs_lst = tokenizer.detokenize_rec(tokenizer.tokenize_rec(self._mrefs_lst)) + + if len(self._cands_dic) == 1: + cands_lst = next(iter(self._cands_dic.values())) + + scores = {} + for metric in self._metrics.values(): + scores |= metric(cands_lst, mrefs_lst)[0] + + scores_dic = {"": scores} + # scores_dic ex: {"": {"fer": tensor(0.1), "fense": tensor(0.3)}} + else: + scores_dic = {} + for cand_name, cands_lst in self._cands_dic.items(): + scores = {} + for metric in self._metrics.values(): + scores |= metric(cands_lst, mrefs_lst)[0] + scores_dic[cand_name] = scores + # scores_dic ex: {"cands.": {"fer": tensor(0.1), "fense": tensor(0.3)}, ...} + + scores = { + f"{prefix}{k}{metric_name}": score + for k, corpus_scores in scores_dic.items() + for metric_name, score in corpus_scores.items() + } + + monitors_scores = {} + for monitor in self._monitors: + score = scores.pop(monitor, None) + if score is not None: + monitors_scores[monitor] = score + + bar_scores = monitors_scores + non_bar_scores = { + name: score for name, score in scores.items() if name not in bar_scores + } + + log_kwargs: dict[str, Any] = dict(on_step=False, on_epoch=True, sync_dist=True) + pl_module.log_dict(bar_scores, prog_bar=True, **log_kwargs) + pl_module.log_dict(non_bar_scores, prog_bar=False, **log_kwargs) + + self._metrics = { + name: metric.to(device=self._other_device) + for name, metric in self._metrics.items() + } diff --git a/src/conette/callbacks/custom_ckpt.py b/src/conette/callbacks/custom_ckpt.py new file mode 100644 index 000000000..f56b156f2 --- /dev/null +++ b/src/conette/callbacks/custom_ckpt.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import math +import os +import os.path as osp +import re + +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from lightning_utilities.core.rank_zero import rank_zero_info + + +pylog = logging.getLogger(__name__) + + +class CustomModelCheckpoint(ModelCheckpoint): + """Custom Model Checkpoint class. + + Changes: + - checkpoint filenames use '-' instead of '=' for separate name and values in checkpoint names + It help for avoiding errors with hydra which also uses the character '=' between arguments and values + - replace "/" from metrics names by "_" in chekcpoint filenames to avoid errors with metrics like "val/loss" in checkpoint filename + - create a symlink "best.ckpt" to the best model path + - track the best monitor candidates. 
(method 'get_best_monitor_candidates()') + - option to save checkpoint after only a certain epoch (arg 'save_after_epoch') + + Example : + With ModelCheckpoint : + epoch=0-step=479-val/loss=3.3178.ckpt + With CustomModelCheckpoint : + epoch_0-step_479-val_loss_3.3178.ckpt + """ + + CHECKPOINT_JOIN_CHAR = "-" + CHECKPOINT_SEP_CHAR = "_" + + def __init__( + self, + # Herited args + dirpath: Optional[Any] = None, + filename: Optional[str] = None, + monitor: Optional[str] = None, + verbose: bool = False, + save_last: Optional[bool] = None, + save_top_k: int = 1, + save_weights_only: bool = False, + mode: str = "min", + auto_insert_metric_name: bool = True, + every_n_train_steps: Optional[int] = None, + train_time_interval: Optional[timedelta] = None, + every_n_epochs: Optional[int] = None, + save_on_train_epoch_end: Optional[bool] = None, + # New args + log_best_score: bool = True, + save_after_epoch: Union[None, int, float] = None, + create_symlink: bool = True, + ) -> None: + if isinstance(dirpath, (str, Path)): + dirpath = osp.expandvars(dirpath) + dirpath = osp.expanduser(dirpath) + + super().__init__( + dirpath=dirpath, + filename=filename, + monitor=monitor, + verbose=verbose, + save_last=save_last, + save_top_k=save_top_k, + save_weights_only=save_weights_only, + mode=mode, + auto_insert_metric_name=auto_insert_metric_name, + every_n_train_steps=every_n_train_steps, + train_time_interval=train_time_interval, + every_n_epochs=every_n_epochs, + save_on_train_epoch_end=save_on_train_epoch_end, + ) + self.log_best_score = log_best_score + self.save_after_epoch = save_after_epoch + self.create_symlink = create_symlink + + self._best_monitor_candidates = {} + + @classmethod + def _format_checkpoint_name( + cls, + filename: Optional[str], + metrics: Dict[str, Any], + prefix: str = "", + auto_insert_metric_name: bool = True, + ) -> str: + if not filename: + # filename is not set, use default name + filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}" + + # check and parse user passed keys in the string + groups = re.findall(r"(\{.*?)[:\}]", filename) + if len(groups) >= 0: + for group in groups: + name = group[1:] + + if auto_insert_metric_name: + # Note source: + # filename = filename.replace(group, name + "={" + name) + # Change LABBETI: replace slash in metrics name by underscore + name_filt = name.replace("/", "_") + # Change LABBETI: SEP char will be "CHECKPOINT_SEP_CHAR" instead of "=" + filename = filename.replace( + group, + name_filt + cls.CHECKPOINT_SEP_CHAR + "{" + name, + ) + + if name not in metrics: + metrics[name] = 0 + filename = filename.format(**metrics) + + if prefix: + filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) + + return filename + + def _save_topk_checkpoint( + self, + trainer: Trainer, + monitor_candidates: dict[str, Any], + ) -> None: + if self.monitor is None or self.save_top_k == 0: + return + + current = monitor_candidates.get(self.monitor) + epoch = monitor_candidates.get("epoch", -1) + step = monitor_candidates.get("step", -1) + + if self.save_after_epoch is None: + min_epoch = -1 + elif isinstance(self.save_after_epoch, int): + min_epoch = self.save_after_epoch + elif isinstance(self.save_after_epoch, float): + if trainer.max_epochs is None: + raise RuntimeError( + f"Cannot use float {self.save_after_epoch=} with {trainer.max_epochs=}." + ) + min_epoch = math.floor(self.save_after_epoch * trainer.max_epochs) + else: + raise TypeError( + f"Invalid argument {self.save_after_epoch=}. 
(expected None, int or float)" + ) + + if self.check_monitor_top_k(trainer, current) and ( + epoch is None or epoch >= min_epoch + ): + self._update_best_and_save(current, trainer, monitor_candidates) # type: ignore + + # Track best epoch and best step + self._best_monitor_candidates = monitor_candidates + + # Log current monitor value + if self.log_best_score and self.best_model_score is not None: + monitor_best_name = f"{self.monitor}_{self.mode}" + trainer.lightning_module.log( + monitor_best_name, + self.best_model_score, + on_epoch=True, + on_step=False, + sync_dist=not trainer.move_metrics_to_cpu, # type: ignore + ) + self._best_monitor_candidates[monitor_best_name] = self.best_model_score + + elif self.verbose: + message_prefix = f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:.2e}," + + # Change LABBETI: Modify info message + if epoch is None or epoch >= min_epoch: + current_best = self._best_monitor_candidates.get(self.monitor, None) + current_best = ( + f"{current_best:.2e}" if current_best is not None else "None" + ) + message_suffix = ( + f"but was not in top {self.save_top_k} (best {current_best})" + ) + else: + message_suffix = f"but found {epoch=} < {min_epoch}" + + rank_zero_info(f"{message_prefix} {message_suffix}") + + def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + if self.dirpath is not None: + os.makedirs(str(self.dirpath), exist_ok=True) + self.to_yaml() + + if ( + not self.create_symlink + or not trainer.is_global_zero + or not osp.isfile(self.best_model_path) + ): + return None + + ckpt_dpath = osp.dirname(self.best_model_path) + ckpt_fname = osp.basename(self.best_model_path) + lpath = osp.join(ckpt_dpath, "best.ckpt") + + if osp.exists(lpath): + pylog.warning(f"Link {osp.basename(lpath)} already exists.") + return None + + os.symlink(ckpt_fname, lpath) + + if not osp.isfile(lpath): + pylog.error(f"Invalid symlink file {lpath=}.") + elif self.verbose: + pylog.debug( + f"Create relative symlink for best model checkpoint '{lpath}'. 
(from='{self.best_model_path}')" + ) + + def get_best_monitor_candidates(self) -> dict[str, Any]: + return self._best_monitor_candidates diff --git a/src/conette/callbacks/debug.py b/src/conette/callbacks/debug.py new file mode 100644 index 000000000..78c20c098 --- /dev/null +++ b/src/conette/callbacks/debug.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +import torch + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks.callback import Callback + +from conette.utils.csum import csum_module + + +pylog = logging.getLogger(__name__) + + +class PrintDebug(Callback): + def __init__(self, verbose: int = 2) -> None: + super().__init__() + self.verbose = verbose + + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_fit_start", self.verbose) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_train_start", self.verbose) + + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_validation_start", self.verbose) + + def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_test_start", self.verbose) + + def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_fit_end", self.verbose) + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_train_end", self.verbose) + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_validation_end", self.verbose) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + _print_csum(pl_module, "on_test_end", self.verbose) + + +def _print_csum(pl_module: LightningModule, fn_name: str, verbose: int) -> None: + if verbose < 2: + return None + + with torch.inference_mode(): + training = pl_module.training + pl_module.train(False) + csum = csum_module(pl_module) + pl_module.train(training) + + pylog.debug( + f"Model checksum for '{fn_name}': {csum} ({len(list(pl_module.named_parameters()))} tensors)" + ) diff --git a/src/conette/callbacks/deepspeed.py b/src/conette/callbacks/deepspeed.py new file mode 100644 index 000000000..0cbe27364 --- /dev/null +++ b/src/conette/callbacks/deepspeed.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Any + +from deepspeed.profiling.flops_profiler import get_model_profile +from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks.callback import Callback +from torch import Tensor + +from conette.nn.functional.misc import move_to_rec +from conette.utils.csum import csum_any + + +pylog = logging.getLogger(__name__) + + +class DeepSpeedCallback(Callback): + def __init__(self, single_input: bool = False, verbose: int = 0) -> None: + super().__init__() + self._single_input = single_input + self._verbose = verbose + self._metrics = {} + + def state_dict(self) -> dict[str, Any]: + return self._metrics + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._metrics |= state_dict + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + return self.profile(pl_module) + + def profile(self, pl_module: LightningModule) -> None: + example = pl_module.example_input_array + + if self._verbose >= 2: + csum = csum_any(example) + pylog.debug(f"Batch 
example csum: {csum}") + + if isinstance(example, dict): + # Assuming that arguments are in the correct order + + if self._single_input: + batch = example["batch"] + single_batch = {} + + for k, v in batch.items(): + if isinstance(v, Tensor): + v = v[0][None] + elif isinstance(v, (list,)): + v = [v[0]] + elif isinstance(v, (tuple,)): + v = (v[0],) + else: + raise TypeError( + f"Invalid item in batch. (found {type(v)} with {v=})" + ) + single_batch[k] = v + + example["batch"] = single_batch + bsize = 1 + + else: + audio = example.get("batch", {}).get("audio") + if audio is None: + bsize = -1 + else: + bsize = len(audio) + + example = move_to_rec(example, device=pl_module.device) + outputs: tuple[int, int, int] = get_model_profile( # type: ignore + pl_module, + kwargs=example, + print_profile=self._verbose >= 2, + as_string=False, + ) + flops, macs, params = outputs + + if bsize != -1: + flops_per_sample = flops / bsize + macs_per_sample = macs / bsize + else: + flops_per_sample = -1 + macs_per_sample = -1 + + if self._verbose >= 1: + pylog.info("According to deepspeed, model has:") + pylog.info(f"- {params} parameters") + + if flops_per_sample == -1: + pylog.info(f"- {flops} FLOPs (with unknown bsize)") + else: + pylog.info( + f"- {flops_per_sample} FLOPs (based on {flops=} with {bsize=})" + ) + + if macs_per_sample == -1: + pylog.info(f"- {macs} MACs (with unknown bsize)") + else: + pylog.info( + f"- {macs_per_sample} MACs (based on {macs=} with {bsize=})" + ) + + metrics = { + "other/dspeed_flops": flops, + "other/dspeed_macs": macs, + "other/dspeed_params": params, + "other/dspeed_bsize": bsize, + "other/dspeed_flops_per_sample": flops_per_sample, + "other/dpseed_macs_per_sample": macs_per_sample, + } + else: + metrics = {} + if self._verbose >= 0: + pylog.warning( + f"Unsupported example type {type(example)}. (expected dict)" + ) + return None + + self._metrics |= metrics + for pl_logger in pl_module.loggers: + pl_logger.log_hyperparams({}, metrics=self._metrics) + + def get_metrics(self) -> dict[str, Any]: + return self._metrics diff --git a/src/conette/callbacks/log.py b/src/conette/callbacks/log.py new file mode 100644 index 000000000..f47ae57e7 --- /dev/null +++ b/src/conette/callbacks/log.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import gc + +from typing import Optional + +import torch + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks.callback import Callback +from torch.optim import Optimizer +from torch.random import get_rng_state + + +class LogLRCallback(Callback): + """Log the learning rate (lr) in the pylog and each iteration.""" + + def __init__( + self, + prefix: str = "train/", + on_epoch: bool = False, + on_step: bool = True, + bsize: Optional[int] = None, + ) -> None: + super().__init__() + self.prefix = prefix + self.on_epoch = on_epoch + self.on_step = on_step + self.bsize = bsize + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + *args, + **kwargs, + ) -> None: + optimizers = pl_module.optimizers() + + if isinstance(optimizers, Optimizer): + optimizers = [optimizers] + elif isinstance(optimizers, (tuple, list)): + pass + else: + raise TypeError( + f"Unsupported optimizers type {type(optimizers)}. (expected Optimizer, tuple[Optimizer, ...] or list[Optimizer])" + ) + + for i, optimizer in enumerate(optimizers): + if not isinstance(optimizer, Optimizer): + raise TypeError( + f"Unsupported optimizers type {type(optimizers)}. 
(expected Optimizer)" + ) + + learning_rates = [ + param_group["lr"] for param_group in optimizer.param_groups + ] + + for j, lr in enumerate(learning_rates): + if len(optimizers) == 1: + if len(learning_rates) == 1: + name = f"{self.prefix}lr" + else: + name = f"{self.prefix}lr{j}" + else: + if len(learning_rates) == 1: + name = f"{self.prefix}optim{i}_lr" + else: + name = f"{self.prefix}optim{i}_lr{j}" + + pl_module.log( + name, + lr, + on_epoch=self.on_epoch, + on_step=self.on_step, + batch_size=self.bsize, + sync_dist=not trainer.move_metrics_to_cpu, # type: ignore + ) + + +class LogGCCallback(Callback): + def __init__(self, prefix: str = "train/", bsize: Optional[int] = None) -> None: + super().__init__() + self.prefix = prefix + self.bsize = bsize + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + *args, + **kwargs, + ) -> None: + counts = gc.get_count() + thresholds = gc.get_threshold() + + for i, (count, threshold) in enumerate(zip(counts, thresholds)): + name = f"{self.prefix}debug_gc_gen{i}" + prop = count / threshold + pl_module.log( + name, + prop, + on_epoch=False, + on_step=True, + batch_size=self.bsize, + sync_dist=not trainer.move_metrics_to_cpu, # type: ignore + ) + + +class LogGradNorm(Callback): + def __init__( + self, + name: str = "train/grad_norm2", + p_norm: int = 2, + bsize: Optional[int] = None, + on_epoch: bool = True, + on_step: bool = False, + ) -> None: + super().__init__() + self.name = name + self.p_norm = p_norm + self.bsize = bsize + self.on_epoch = on_epoch + self.on_step = on_step + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + *args, + **kwargs, + ) -> None: + parameters = [ + param.grad.norm(p=self.p_norm) # type: ignore + for param in pl_module.parameters() + if param.grad is not None + ] + grad_norm = torch.as_tensor(parameters, dtype=torch.float64).sum() + + pl_module.log( + self.name, + grad_norm, + on_epoch=self.on_epoch, + on_step=self.on_step, + batch_size=self.bsize, + sync_dist=not trainer.move_metrics_to_cpu, # type: ignore + ) + + +class LogRngState(Callback): + def __init__(self, prefix: str = "train/", bsize: Optional[int] = None) -> None: + super().__init__() + self.prefix = prefix + self.bsize = bsize + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + *args, + **kwargs, + ) -> None: + rng_state = get_rng_state().sum().float() + pl_module.log( + f"{self.prefix}rng_state", + rng_state, + on_epoch=True, + on_step=False, + batch_size=self.bsize, + sync_dist=not trainer.move_metrics_to_cpu, # type: ignore + ) diff --git a/src/conette/callbacks/resume.py b/src/conette/callbacks/resume.py new file mode 100644 index 000000000..d54465eff --- /dev/null +++ b/src/conette/callbacks/resume.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import glob +import logging +import os.path as osp +import re + +from typing import Iterable, Optional, Union + +import torch + +from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks.callback import Callback +from torch import Tensor + +from conette.utils.csum import csum_module + + +pylog = logging.getLogger(__name__) + + +class ResumeCallback(Callback): + def __init__( + self, + resume: Optional[str], + strict: bool = True, + ign_weights: Union[str, Iterable[str]] = (), + use_glob: bool = False, + verbose: int = 1, + ) -> None: + """ + :param pl_resume_path: The path to the checkpoint file containing the weights or to the logdir path containing the weight 
file in '{pl_resume_path}/checkpoints/best.ckpt'. + :param strict: If True, the loading will crash if all keys weights does not match with the pl_module. defaults to False. + :param verbose: The verbose level. defaults to 1. + """ + super().__init__() + self._resume = resume + self._strict = strict + self._ign_weights = ign_weights + self._use_glob = use_glob + self._verbose = verbose + + self._loaded = False + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + self.load_checkpoint(pl_module) + + def on_validation_start(self, trainer, pl_module: LightningModule) -> None: + self.load_checkpoint(pl_module) + + def on_test_start(self, trainer, pl_module: LightningModule) -> None: + self.load_checkpoint(pl_module) + + def load_checkpoint(self, pl_module: LightningModule) -> None: + if self._loaded: + return None + + load_checkpoint( + pl_module=pl_module, + resume=self._resume, + strict=self._strict, + ign_weights=self._ign_weights, + use_glob=self._use_glob, + verbose=self._verbose, + ) + self._loaded = True + + +def load_checkpoint( + pl_module: LightningModule, + resume: Optional[str], + strict: bool = True, + ign_weights: Union[str, Iterable[str]] = (), + use_glob: bool = False, + verbose: int = 0, +) -> None: + if resume is None: + return None + + if isinstance(ign_weights, str): + ign_weights = [ign_weights] + else: + ign_weights = list(ign_weights) + + if use_glob: + matchs = glob.glob(resume) + if len(matchs) == 0: + raise ValueError(f"Cannot find ckpt file with glob pattern '{resume}'.") + elif len(matchs) > 1: + raise ValueError( + f"Found multiple ckpt files with glob pattern '{resume}'. (found {len(matchs)} matchs)" + ) + resume = matchs[0] + + if not isinstance(resume, str) or not osp.exists(resume): + raise ValueError( + f"Invalid resume checkpoint fpath {resume=}. (path does not exists)" + ) + + if osp.isfile(resume): + ckpt_fpath = resume + elif osp.isdir(resume): + ckpt_fpath = osp.join(resume, "checkpoints", "best.ckpt") + if not osp.isfile(ckpt_fpath): + raise FileNotFoundError( + f"Cannot find checkpoint in {resume=} (expected in {{resume}}/checkpoints/best.ckpt)." + ) + else: + raise ValueError(f"Invalid path type {resume=}.") + + if verbose >= 1: + pylog.info(f"Loading pl_module from checkpoint {ckpt_fpath=}.") + pylog.debug(f"pl_module csum before resume weights = {csum_module(pl_module)}") + + # Load best model before training + checkpoint_data = torch.load(ckpt_fpath, map_location=pl_module.device) + state_dict: dict[str, Tensor] = checkpoint_data["state_dict"] + state_dict = { + k: v + for k, v in state_dict.items() + if all((re.match(pattern, k) is None) for pattern in ign_weights) + } + + try: + incompatible_keys = pl_module.load_state_dict(state_dict, strict=strict) + + if verbose >= 2: + pylog.debug(f"Found incompatible keys: {incompatible_keys}") + + except RuntimeError as err: + pylog.error( + f"Cannot load weights from ckpt file '{ckpt_fpath}'. 
(with strict={strict})" + ) + raise err + + if verbose >= 1: + pylog.debug(f"pl_module csum after resume weights = {csum_module(pl_module)}") diff --git a/src/conette/callbacks/stats_saver.py b/src/conette/callbacks/stats_saver.py new file mode 100644 index 000000000..0618eb369 --- /dev/null +++ b/src/conette/callbacks/stats_saver.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import logging +import os +import os.path as osp + +from argparse import Namespace +from typing import Any, Iterable, Optional, Union + +import yaml + +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.callbacks.checkpoint import Checkpoint +from pytorch_lightning.core.saving import save_hparams_to_yaml +from torch import Tensor + +from conette.callbacks.time import TimeTrackerCallback +from conette.info import get_install_info +from conette.nn.functional.misc import count_params +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.utils.csum import csum_module +from conette.utils.custom_logger import CustomTensorboardLogger +from conette.utils.misc import get_current_git_hash, save_conda_env, save_micromamba_env + + +pylog = logging.getLogger(__name__) + + +class StatsSaver(Callback): + """Callback for saving some stats about the training in the pylog.""" + + def __init__( + self, + subrun_path: Optional[str], + tokenizers: Optional[dict[str, Optional[AACTokenizer]]] = None, + on_end: str = "none", + close_logger_on_end: bool = True, + git_hash: Optional[str] = None, + cfg: Optional[DictConfig] = None, + verbose: int = 1, + ) -> None: + if subrun_path is not None: + subrun_path = osp.expandvars(subrun_path) + + if tokenizers is None: + tokenizers = {} + else: + tokenizers = { + name: tokenizer + for name, tokenizer in tokenizers.items() + if tokenizer is not None + } + + if on_end not in ("fit", "test", "none"): + raise ValueError(f"Invalid argument {on_end=}.") + + if git_hash is None: + git_hash = get_current_git_hash(default=None) + + super().__init__() + self._subrun_dir = subrun_path + self._tokenizers = tokenizers + self._on_end = on_end + self._close_logger_on_end = close_logger_on_end + self._git_hash = git_hash + self._cfg = cfg + self._verbose = verbose + + self._time_tracker = TimeTrackerCallback() + self._start_csum = 0 + self._end_csum = 0 + + # Callback methods + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._time_tracker.on_fit_start(trainer, pl_module) + self._start_csum = csum_module(pl_module) + self._end_csum = self._start_csum + + def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._time_tracker.on_fit_end(trainer, pl_module) + self._end_csum = csum_module(pl_module) + + if self._on_end == "fit": + self.save_metrics_stats(trainer, pl_module) + + def on_train_epoch_start( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + self._time_tracker.on_train_epoch_start(trainer, pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._time_tracker.on_train_epoch_end(trainer, pl_module) + + def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._time_tracker.on_test_start(trainer, pl_module) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + 
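+        # Mirrors on_fit_end: the collected stats are written out here only when on_end == "test".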
self._time_tracker.on_test_end(trainer, pl_module) + if self._on_end == "test": + self.save_metrics_stats(trainer, pl_module) + + # Other methods + def save_metrics_stats( + self, + trainer: Trainer, + pl_module: Optional[LightningModule], + datamodule: Optional[LightningDataModule] = None, + add_params: Optional[dict[str, Any]] = None, + add_metrics: Optional[dict[str, Any]] = None, + ) -> None: + if datamodule is None: + datamodule = trainer.datamodule # type: ignore + ckpts = trainer.checkpoint_callbacks if trainer is not None else None + + save_to_dir( + subrun_path=self._subrun_dir, + datamodule=datamodule, + pl_module=pl_module, + tokenizers=self._tokenizers, # type: ignore + time_tracker=self._time_tracker, + checkpoint=ckpts, + close_logger_on_end=self._close_logger_on_end, + git_hash=self._git_hash, + cfg=self._cfg, + start_csum=self._start_csum, + end_csum=self._end_csum, + verbose=self._verbose, + add_params=add_params, + add_metrics=add_metrics, + ) + + +def save_to_dir( + subrun_path: Optional[str], + datamodule: Optional[LightningDataModule] = None, + pl_module: Optional[LightningModule] = None, + tokenizers: Optional[dict[str, Optional[AACTokenizer]]] = None, + time_tracker: Optional[TimeTrackerCallback] = None, + checkpoint: Union[Checkpoint, Iterable[Checkpoint], None] = None, + close_logger_on_end: bool = True, + git_hash: Optional[str] = None, + cfg: Optional[DictConfig] = None, + start_csum: Optional[int] = None, + end_csum: Optional[int] = None, + verbose: int = 0, + add_slurm_vars_to_params: bool = False, + add_version_info_to_params: bool = False, + add_params: Optional[dict[str, Any]] = None, + add_metrics: Optional[dict[str, Any]] = None, + save_conda: bool = False, + save_micromamba: bool = True, +) -> None: + """Save callbacks and miscellaneous information in subrun_path directory.""" + if subrun_path is None: + return None + + subrun_path = osp.expandvars(subrun_path) + if not osp.isdir(subrun_path): + return None + + if add_params is None: + params = {} + else: + params = add_params + + if add_metrics is None: + other_metrics = {} + else: + other_metrics = add_metrics + + if git_hash is None: + git_hash = get_current_git_hash(default=None) + + params |= { + "git_hash": git_hash, + "start_csum": start_csum, + "end_csum": end_csum, + } + + if add_slurm_vars_to_params: + params |= { + key.lower(): value + for key, value in os.environ.items() + if key.startswith("SLURM_") + } + + if add_version_info_to_params: + versions = get_install_info() + versions = {f"{name}_version": version for name, version in versions.items()} + params |= versions + + hp_dpath = osp.join(subrun_path, "hparams") + os.makedirs(hp_dpath, exist_ok=True) + + # Note: do not use save_hparams_to_yaml for os.environ to avoid interpolation errors + with open(osp.join(hp_dpath, "os_env.yaml"), "w") as file: + yaml.dump(dict(os.environ), file, sort_keys=False) + + if save_conda: + if cfg is not None: + conda_path = cfg.get("path", {}).get("conda", "conda") + else: + conda_path = "conda" + save_conda_env(osp.join(hp_dpath, "conda_env.yaml"), conda_path) + + if save_micromamba: + if cfg is not None: + micromamba_path = cfg.get("path", {}).get("micromamba", "micromamba") + else: + micromamba_path = "micromamba" + save_micromamba_env(osp.join(hp_dpath, "micromamba_env.yaml"), micromamba_path) + + hydra_cfg = HydraConfig.get() + hydra_cfg = {"hydra": hydra_cfg} + save_hparams_to_yaml(osp.join(hp_dpath, "resolved_hydra.yaml"), hydra_cfg) + + if cfg is not None: + save_hparams_to_yaml(osp.join(hp_dpath, 
"resolved_config.yaml"), cfg) # type: ignore + + if pl_module is not None: + save_hparams_to_yaml( + osp.join(hp_dpath, "pl_module.yaml"), + pl_module.hparams_initial, + ) + other_metrics |= { + "total_params": count_params(pl_module, only_trainable=False), + "train_params": count_params(pl_module, only_trainable=True), + } + + if datamodule is not None: + save_hparams_to_yaml( + osp.join(hp_dpath, "datamodule.yaml"), + datamodule.hparams_initial, + ) + + if time_tracker is not None: + params |= { + "fit_duration": time_tracker.get_fit_duration_formatted(), + "test_duration": time_tracker.get_test_duration_formatted(), + } + other_metrics |= { + "fit_duration_h": time_tracker.get_fit_duration_in_hours(), + "test_duration_h": time_tracker.get_test_duration_in_hours(), + "epoch_mean_duration_min": time_tracker.get_epoch_mean_duration_in_min(), + } + + if checkpoint is None: + ckpts = [] + elif not isinstance(checkpoint, Iterable): + ckpts = [checkpoint] + else: + ckpts = list(checkpoint) + del checkpoint + + for ckpt in ckpts: + if not all( + hasattr(ckpt, attr) for attr in ("get_best_monitor_candidates", "monitor") + ): + pylog.warning( + f"Cannot save best epoch values for checkpoint type {ckpt.__class__.__name__}." + ) + continue + + best_monitor_candidates: dict[str, Any] = ckpt.get_best_monitor_candidates() # type: ignore + + monitor = ckpt.monitor # type: ignore + # note : no need to handle case where / is not found because: + # example: s = "abcabc"; s.rfind("d") gives -1, so s[s.rfind("d")+1:] == s[0:] == s + monitor = monitor[monitor.rfind("/") + 1 :] + + best_monitor_candidates = { + f"best_{monitor}_{name}": _clean_value(value) + for name, value in best_monitor_candidates.items() + } + other_metrics |= best_monitor_candidates + + if tokenizers is None: + tokenizers = {} + + for name, tokenizer in tokenizers.items(): + if tokenizer is None: + continue + + # Save tokenizer to pickle file + tokenizer_fname = f"{name}.pickle" + tokenizer_fpath = osp.join(subrun_path, tokenizer_fname) + tokenizer.save_file(tokenizer_fpath) + + # Save tokenizer hparams to yaml file + hparams_fpath = osp.join(hp_dpath, f"{name}.yaml") + hparams = tokenizer.get_hparams() + with open(hparams_fpath, "w") as file: + yaml.dump(hparams, file) + + if tokenizer.is_fit(): + # Save vocabulary to csv file + vocab_fname = f"vocabulary_{name}.csv" + vocab_fpath = osp.join(subrun_path, vocab_fname) + + fieldnames = ("token", "occurrence", "index") + data = [ + { + "token": token, + "occurrence": occurrence, + "index": tokenizer.token_to_id(token), + } + for token, occurrence in tokenizer.get_vocab().items() + ] + + with open(vocab_fpath, "w") as file: + writer = csv.DictWriter(file, fieldnames) + writer.writeheader() + writer.writerows(data) # type: ignore + + other_metrics[f"{name}_vocab_size"] = tokenizer.get_vocab_size() + other_metrics[ + f"{name}_min_sentence_size" + ] = tokenizer.get_min_sentence_size() + other_metrics[ + f"{name}_max_sentence_size" + ] = tokenizer.get_max_sentence_size() + + # Remove optional None values + params = {k: v for k, v in params.items() if v is not None} + other_metrics = {k: v for k, v in other_metrics.items() if v is not None} + + other_metrics = {f"other/{name}": value for name, value in other_metrics.items()} + + if verbose >= 2: + pylog.debug( + f"Adding {len(params)} params :\n{yaml.dump(params, sort_keys=False)}" + ) + pylog.debug( + f"Adding {len(other_metrics)} metrics :\n{yaml.dump(other_metrics, sort_keys=False)}" + ) + + # Store params and metrics + if pl_module is not 
None: + for pl_logger in pl_module.loggers: + if isinstance(pl_logger, CustomTensorboardLogger): + pl_logger.log_hyperparams(params=params, metrics=other_metrics) + + if close_logger_on_end: + pl_logger.save_and_close() + else: + ns_params = Namespace() + ns_params.__dict__.update(params) + pl_logger.log_hyperparams(ns_params) + pl_logger.log_metrics(other_metrics) + + +def _clean_value(value) -> Any: + if isinstance(value, Tensor): + if value.ndim == 0: + return value.item() + else: + return value.tolist() + else: + return value diff --git a/src/conette/callbacks/time.py b/src/conette/callbacks/time.py new file mode 100644 index 000000000..643e47b5c --- /dev/null +++ b/src/conette/callbacks/time.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import time + +from typing import Optional, Union + +from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks.callback import Callback + + +class TimeTrackerCallback(Callback): + def __init__(self, log_fit_duration: bool = False) -> None: + super().__init__() + self._log_fit_duration = log_fit_duration + + self._fit_start_time = 0.0 + self._fit_end_time = 0.0 + self._test_start_time = 0.0 + self._test_end_time = 0.0 + self._epoch_starts = [] + self._epoch_ends = [] + self._n_fit_ended = 0 + self._n_test_ended = 0 + + def on_fit_start(self, trainer, pl_module) -> None: + self._fit_start_time = time.perf_counter() + + def on_fit_end(self, trainer, pl_module: LightningModule) -> None: + self._fit_end_time = time.perf_counter() + + if self._log_fit_duration: + key = "fit_duration" + ( + "" if self._n_test_ended == 0 else str(self._n_test_ended) + ) + pl_module.log(key, self.get_fit_duration(), on_step=False, on_epoch=True) + self._n_fit_ended += 1 + + def on_train_epoch_start(self, trainer, pl_module) -> None: + self._epoch_starts.append(time.perf_counter()) + + def on_train_epoch_end(self, trainer, pl_module) -> None: + self._epoch_ends.append(time.perf_counter()) + + def on_test_start(self, trainer, pl_module) -> None: + self._test_start_time = time.perf_counter() + + def on_test_end(self, trainer, pl_module) -> None: + self._test_end_time = time.perf_counter() + + if self._log_fit_duration: + key = "test_duration" + ( + "" if self._n_test_ended == 0 else str(self._n_test_ended) + ) + pl_module.log(key, self.get_fit_duration(), on_step=False, on_epoch=True) + self._n_test_ended += 1 + + def get_fit_duration(self) -> float: + """Return the fit duration in seconds.""" + return self._fit_end_time - self._fit_start_time + + def get_test_duration(self) -> float: + """Return the test duration in seconds.""" + return self._test_end_time - self._test_start_time + + def get_fit_duration_in_hours(self) -> float: + return self.get_fit_duration() / 3600.0 + + def get_test_duration_in_hours(self) -> float: + return self.get_test_duration() / 3600.0 + + def get_fit_duration_formatted(self) -> str: + """Return the fit duration as ISO format ddTHH:mm:ss.""" + return format_duration(self.get_fit_duration()) + + def get_test_duration_formatted(self) -> str: + """Return the test duration as ISO format ddTHH:mm:ss.""" + return format_duration(self.get_test_duration()) + + def get_epoch_mean_duration_in_min(self, epoch: Optional[int] = None) -> float: + if len(self._epoch_ends) > 0: + if epoch is None: + if len(self._epoch_starts) == len(self._epoch_ends): + maxidx = None + elif len(self._epoch_starts) - 1 == len(self._epoch_ends): + maxidx = -1 + else: + raise ValueError("Invalid epoch starts list.") + + return ( + 
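+                    # mean duration of the completed epochs, converted from seconds to minutes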
(sum(self._epoch_ends) - sum(self._epoch_starts[:maxidx])) + / len(self._epoch_ends) + / 60.0 + ) + else: + return (self._epoch_ends[epoch] - self._epoch_starts[epoch]) / 60.0 + else: + return -1.0 + + +def format_duration( + duration_sec: Union[int, float], + days_hours_sep: str = "_", # "T" + other_sep: str = "-", # ":" + force_days: bool = False, +) -> str: + """Get formatted duration as {dd}_{HH}-{mm}-{ss} or {HH}-{mm}-{ss}""" + duration_sec = int(duration_sec) + rest, seconds = divmod(duration_sec, 60) + rest, minutes = divmod(rest, 60) + if rest > 24 or force_days: + days, hours = divmod(rest, 24) + duration_str = f"{days:02d}{days_hours_sep}{hours:02d}{other_sep}{minutes:02d}{other_sep}{seconds:02d}" + else: + hours = rest + duration_str = f"{hours:02d}{other_sep}{minutes:02d}{other_sep}{seconds:02d}" + return duration_str diff --git a/src/conette/datamodules/__init__.py b/src/conette/datamodules/__init__.py new file mode 100644 index 000000000..58402e1b7 --- /dev/null +++ b/src/conette/datamodules/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""LightningDataModules (datamodules) directory. +""" diff --git a/src/conette/datamodules/aac_dm.py b/src/conette/datamodules/aac_dm.py new file mode 100644 index 000000000..37cbc5750 --- /dev/null +++ b/src/conette/datamodules/aac_dm.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from argparse import Namespace +from typing import Any, Callable, Iterable, Optional + +from pytorch_lightning import LightningDataModule +from torch.utils.data.dataloader import DataLoader + + +pylog = logging.getLogger(__name__) + + +class AACDataModule(LightningDataModule): + DISABLE_TEARDOWN: bool = False + _IGNORE_ARGS: tuple[str, ...] = () + + def __init__( + self, + root: str = "data", + bsize: int = 512, + n_workers: Optional[int] = 0, + pin_memory: bool = True, + train_drop_last: bool = False, + verbose: int = 1, + train_cols: Iterable[str] = (), + val_cols: Iterable[str] = (), + test_cols: Iterable[str] = (), + ) -> None: + super().__init__() + self._setup_fit_done = False + self._setup_test_done = False + self._setup_predict_done = False + + self._train_dset: Any = None + self._val_dset: Any = None + self._test_dsets: dict[str, Any] = {} + self._predict_dsets: dict[str, Any] = {} + + self._train_collate: Optional[Callable] = None + self._val_collate: Optional[Callable] = None + self._test_collate: Optional[Callable] = None + self._predict_collate: Optional[Callable] = None + + self.save_hyperparameters(ignore=self._IGNORE_ARGS) + + # Abstract methods + def _setup_fit(self) -> None: + raise NotImplementedError("Abstract method") + + def _setup_test(self) -> None: + raise NotImplementedError("Abstract method") + + def _setup_predict(self) -> None: + raise NotImplementedError("Abstract method") + + # LightningDataModule methods + def prepare_data(self) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + if stage in ("fit", "validate", None) and not self._setup_fit_done: + if self.hp.verbose >= 1: + pylog.info("Starting fit setup...") + + self._setup_fit() + self._setup_fit_done = True + + if self.hp.verbose >= 1: + dsets = {"train": self._train_dset, "val": self._val_dset} + dsets = {name: ds for name, ds in dsets.items() if ds is not None} + sizes = {name: len(ds) for name, ds in dsets.items()} + pylog.info(f"Setup for train is done with {sizes}.") + + if stage in ("test", None) and not self._setup_test_done: + if self.hp.verbose >= 1: + pylog.info("Starting 
test setup...") + + self._setup_test() + self._setup_test_done = True + + if self.hp.verbose >= 1: + dsets = self._test_dsets + dsets = { + name: ds for name, ds in self._test_dsets.items() if ds is not None + } + sizes = {name: len(ds) for name, ds in dsets.items()} + pylog.info(f"Setup for test is done with {sizes}.") + + if stage in ("predict", None) and not self._setup_predict_done: + if self.hp.verbose >= 1: + pylog.info("Starting predict setup...") + + self._setup_predict() + self._setup_predict_done = True + + if self.hp.verbose >= 1: + sizes = {name: len(dset) for name, dset in self._predict_dsets.items()} + pylog.info(f"Setup for predict is done with {sizes}.") + + def teardown(self, stage: Optional[str] = None) -> None: + if self.DISABLE_TEARDOWN: + if self.hp.verbose >= 0: + pylog.warning( + f"Teardown has been called with {stage=}, but it has been disabled with global var DISABLE_TEARDOWN." + ) + return None + + if self.hp.verbose >= 2: + pylog.debug(f"Teardown stage {stage}.") + + # note: do not teardown when stage=="validate" to avoid re-build fit datasets twice + if stage in ("fit", None): + self._train_dset: Any = None + self._val_dset: Any = None + self._setup_fit_done = False + + if stage in ("test", None): + self._test_dsets = {} + self._predict_dsets = {} + self._setup_test_done = False + + if stage in ("predict", None): + self._predict_dsets = {} + self._setup_predict_done = False + + def train_dataloader(self) -> DataLoader: + if self.hp.verbose >= 2: + pylog.debug("Build train dataloader(s)...") + + return DataLoader( + dataset=self._train_dset, + batch_size=self.hp.bsize, + num_workers=self.hp.n_workers, + shuffle=True, + collate_fn=self._train_collate, + pin_memory=self.hp.pin_memory, + drop_last=self.hp.train_drop_last, + sampler=None, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + dataset=self._val_dset, + batch_size=self.hp.bsize, + num_workers=self.hp.n_workers, + shuffle=False, + collate_fn=self._val_collate, + pin_memory=self.hp.pin_memory, + drop_last=False, + ) + + def test_dataloader(self) -> list[DataLoader]: + return [ + DataLoader( + dataset=dset, + batch_size=self.hp.bsize, + num_workers=self.hp.n_workers, + shuffle=False, + collate_fn=self._test_collate, + pin_memory=self.hp.pin_memory, + drop_last=False, + ) + for dset in self._test_dsets.values() + ] + + def predict_dataloader(self) -> list[DataLoader]: + return [ + DataLoader( + dataset=dset, + batch_size=self.hp.bsize, + num_workers=self.hp.n_workers, + shuffle=False, + collate_fn=self._predict_collate, + pin_memory=self.hp.pin_memory, + drop_last=False, + ) + for dset in self._predict_dsets.values() + ] + + # Other methods + @property + def hp(self) -> Namespace: + return Namespace(**self.hparams) + + @property + def hp_init(self) -> Namespace: + return Namespace(**self.hparams_initial) + + @property + def root(self) -> str: + return self.hparams["root"] + + @property + def bsize(self) -> int: + return self.hparams["bsize"] + + @property + def n_workers(self) -> int: + return self.hparams["n_workers"] + + @property + def pin_memory(self) -> bool: + return self.hparams["pin_memory"] + + @property + def verbose(self) -> int: + return self.hparams["verbose"] + + @property + def train_cols(self) -> list[str]: + return self.hparams["train_cols"] + + @property + def val_cols(self) -> list[str]: + return self.hparams["val_cols"] + + @property + def test_cols(self) -> list[str]: + return self.hparams["test_cols"] diff --git a/src/conette/datamodules/collate.py 
b/src/conette/datamodules/collate.py new file mode 100644 index 000000000..24bd49e1a --- /dev/null +++ b/src/conette/datamodules/collate.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Any, Iterable, Optional + +import torch + +from torch import Tensor + +from conette.datasets.hdf.common import SHAPE_SUFFIX +from conette.nn.functional.misc import can_be_stacked +from conette.nn.functional.pad import pad_sequence_rec +from conette.utils.collections import all_eq + + +pylog = logging.getLogger(__name__) + + +class CollateDict: + """Collate list of dict into a dict of list WITHOUT auto-padding.""" + + def __call__(self, items_lst: list[dict[str, Any]]) -> dict[str, list[Any]]: + common_keys = items_lst[0].keys() + for i in range(1, len(items_lst)): + common_keys = [key for key in common_keys if key in items_lst[i].keys()] + return {key: [item[key] for item in items_lst] for key in common_keys} + + +class AdvancedCollateDict: + def __init__( + self, + pad_values: Optional[dict[str, Any]] = None, + crop_keys: Iterable[str] = (), + batch_keys: Optional[Iterable[str]] = None, + ) -> None: + """Collate list of dict into a dict of list WITH auto-padding for given keys. + + :param pad_values: The dictionnary of key with pad value. + :param crop_keys: Depreciated crop keys. + :param batch_keys: The expected batch keys. + """ + + if pad_values is None: + pad_values = {} + crop_keys = list(dict.fromkeys(crop_keys)) + if batch_keys is not None: + batch_keys = list(batch_keys) + + super().__init__() + self._pad_values = pad_values + self._crop_keys = crop_keys + self._batch_keys = batch_keys + + def __call__(self, batch_lst: list[dict[str, Any]]) -> dict[str, Any]: + if self._batch_keys is None: + # Intersection of keys and keep the same order + batch_keys = list(batch_lst[0].keys()) + for item in batch_lst[1:]: + batch_keys = [key for key in batch_keys if key in item.keys()] + else: + batch_keys = self._batch_keys + + batch_dic: dict[str, Any] = { + key: [item[key] for item in batch_lst] for key in batch_keys + } + batch_dic = { + key: (torch.stack(items) if key.endswith(SHAPE_SUFFIX) else items) + for key, items in batch_dic.items() + } + + for key in batch_keys: + items = batch_dic[key] + key_shape = f"{key}{SHAPE_SUFFIX}" + + if key in self._crop_keys: + shapes = batch_dic[key_shape] + max_shape = shapes.max(dim=0).values + + slices = [slice(shape_i) for shape_i in max_shape] + for i in range(len(items)): + items[i] = items[i][slices] + items = torch.stack(items) + + elif key in self._pad_values.keys(): + if key_shape not in batch_dic.keys(): + try: + shapes = [item.shape for item in items] + except AttributeError as err: + raise err + if not all_eq(map(len, shapes)): + pylog.error( + f"Cannot collate list of tensors with a different number of dims. 
({shapes=})" + ) + continue + + shapes = torch.as_tensor(shapes) + batch_dic[key_shape] = shapes + + pad_value = self._pad_values[key] + items = pad_sequence_rec(items, pad_value=pad_value) + + elif ( + not key.endswith(SHAPE_SUFFIX) + and all(isinstance(item, Tensor) for item in items) + and can_be_stacked(items) + ): + items = torch.stack(items) + + batch_dic[key] = items + + return batch_dic + + +def detect_scalar_type(item: Any) -> type: + types = set() + queue = [item] + while len(queue) > 0: + item = queue.pop() + if isinstance(item, (list, tuple)) and len(item) > 0: + queue += item + else: + types.add(type(item)) + + if len(types) == 1: + return list(types)[0] + else: + raise RuntimeError(f"Multiple types detected: {types=}.") + + +def detect_shape(item: Any) -> Tensor: + if isinstance(item, (int, float, str)): + return torch.as_tensor((), dtype=torch.long) + elif isinstance(item, Tensor) and item.ndim in (0, 1): + return torch.as_tensor(item.shape, dtype=torch.long) + elif isinstance(item, (Tensor, list, tuple)): + if len(item) == 0 or isinstance(item[0], (int, float, str)): + return torch.as_tensor((len(item),), dtype=torch.long) + else: + subshapes = [detect_shape(subitem) for subitem in item] + subdims = list(map(len, subshapes)) + if not all_eq(subdims): + pylog.error(f"Function detech_shape: found {subshapes=}") + raise RuntimeError( + f"Invalid number of dims with {subdims=} in function 'detect_shape'." + ) + return torch.stack([torch.as_tensor(subshape) for subshape in subshapes]) + else: + raise RuntimeError(f"Unknown subtype {item.__class__.__name__}.") diff --git a/src/conette/datamodules/common.py b/src/conette/datamodules/common.py new file mode 100644 index 000000000..afc317472 --- /dev/null +++ b/src/conette/datamodules/common.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import math +import random +import re +import os +import os.path as osp + +from collections import Counter +from typing import Any, Callable, Iterable, Optional, Union + +import torch +import yaml + +from nltk.util import ngrams +from torch import Generator, Tensor +from torch.utils.data.dataset import ConcatDataset + +from aac_datasets.datasets.audiocaps import AudioCaps + +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.datasets.hdf import HDFDataset +from conette.datasets.typing import SizedDatasetLike +from conette.datasets.utils import ( + TransformWrapper, + ZipDataset, +) + + +pylog = logging.getLogger(__name__) + + +def get_hdf_fpaths( + dataname: str, + subsets: Iterable[str], + hdf_root: str, + hdf_suffix: Optional[str], + hdf_dname: str = "HDF", +) -> dict[str, str]: + """Returns the dictionary of HDF datasets filepaths for each subset : + ``` + { + {subset_1}: {hdf_root}/{hdf_dname}/{dataname}_{subset_1}_{hdf_suffix}.hdf + {subset_2}: {hdf_root}/{hdf_dname}/{dataname}_{subset_2}_{hdf_suffix}.hdf + ... + } + ``` + If hdf_suffix is None, returns an empty dict. 
+ """ + if hdf_suffix is None: + return {} + + dataname = dataname.lower() + subsets = list(map(str.lower, subsets)) + pattern = re.compile( + r"(?P[a-z]+)_(?P[a-z]+)_(?P.+)\.hdf" + ) + hdf_root = osp.expandvars(hdf_root) + + if not osp.isdir(osp.join(hdf_root, hdf_dname)): + raise FileNotFoundError(f"Cannot find {hdf_dname} directory in {hdf_root=}.") + + hdf_fpaths = {} + + for subset in subsets: + hdf_fname = f"{dataname}_{subset}_{hdf_suffix}.hdf" + hdf_fpath = osp.join(hdf_root, hdf_dname, hdf_fname) + + if not osp.isfile(hdf_fpath): + names = os.listdir(osp.join(hdf_root, hdf_dname)) + matches = [re.match(pattern, name) for name in names] + availables_hdf_suffix = [ + match["hdf_suffix"] + for match in matches + if match is not None + and match["dataname"] == dataname + and match["subset"] == subset + ] + + pylog.error( + f"Cannot find HDF file '{hdf_fpath}' with {hdf_suffix=}.\n" + f"Maybe run conette-prepare before and use another hdf_suffix for {dataname}.\n" + f"Available hdf_suffix for '{dataname}_{subset}' are:\n{yaml.dump(availables_hdf_suffix, sort_keys=False)}" + ) + hdf_fpaths[subset] = hdf_fpath + + return hdf_fpaths + + +class PreEncodedCaptionsTransform: + def __init__( + self, + audio_tfm: Optional[Callable], + ref_selection: Union[str, int, slice], + add_raw_refs: bool, + mult_captions: Union[list, Tensor], + mrefs_src_key: str = "captions", + ) -> None: + super().__init__() + self.audio_tfm = audio_tfm + self.ref_selection = ref_selection + self.mult_captions = mult_captions + self.add_raw_refs = add_raw_refs + self.mrefs_src_key = mrefs_src_key + + def __call__(self, item: dict[str, Any]) -> dict[str, Any]: + item_idx = item["index"] + captions = self.mult_captions[item_idx] + references = item[self.mrefs_src_key] + + if self.audio_tfm is not None: + item["audio"] = self.audio_tfm(item["audio"]) + + if isinstance(self.ref_selection, str): + if self.ref_selection == "random": + idxs = random.randint(0, len(captions) - 1) + else: + raise ValueError(f"Invalid argument {self.ref_selection=}.") + else: + idxs = self.ref_selection + + if isinstance(idxs, int): + caption = captions[idxs] + reference = references[idxs] + + item["captions"] = caption + if self.add_raw_refs: + item["references"] = reference + + elif idxs == slice(None): + item.pop("captions") + item["mult_captions"] = captions + if self.add_raw_refs: + item["mult_references"] = references + + else: + raise ValueError(f"Invalid argument {idxs=} with {self.ref_selection=}.") + + return item + + +class OnlineEncodeCaptionsTransform: + def __init__( + self, + audio_tfm: Optional[Callable[[Tensor], Tensor]], + ref_selection: Union[str, int, slice], + add_raw_refs: bool, + tokenizer: AACTokenizer, + encode_kwargs: dict[str, Any], + mrefs_src_key: Optional[str] = "captions", + audio_time_dim: int = -2, + ref_tfm: Optional[Callable[[str], str]] = None, + ) -> None: + super().__init__() + self.audio_tfm = audio_tfm + self.ref_selection = ref_selection + self.add_raw_refs = add_raw_refs + self.tokenizer = tokenizer + self.encode_kwargs = encode_kwargs + self.mrefs_src_key = mrefs_src_key + self.audio_time_dim = audio_time_dim + self.ref_tfm = ref_tfm + + def __call__(self, item: dict[str, Any]) -> dict[str, Any]: + if self.audio_tfm is not None: + audio = item["audio"] + audio_shape = item["audio_shape"] + audio_len = audio_shape[self.audio_time_dim] + if audio_len < audio.shape[self.audio_time_dim]: + mask = [slice(None) for _ in range(audio.ndim)] + mask[self.audio_time_dim] = slice(audio_len) + audio[mask] = 
self.audio_tfm(audio[mask]) + else: + audio = self.audio_tfm(audio) + item["audio"] = audio.contiguous() + + if self.mrefs_src_key is not None: + refs = item[self.mrefs_src_key] + + if isinstance(self.ref_selection, str): + if self.ref_selection == "random": + idxs = random.randint(0, len(refs) - 1) + else: + raise ValueError(f"Invalid argument {self.ref_selection=}.") + else: + idxs = self.ref_selection + + if isinstance(idxs, int): + ref = refs[idxs] + if self.ref_tfm is not None: + ref = self.ref_tfm(ref) + + cap = self.tokenizer.encode_single( + ref, + **self.encode_kwargs, + ) + + item["captions"] = cap + if self.add_raw_refs: + item["references"] = ref + + elif idxs == slice(None): + item.pop("captions") + + if self.ref_tfm is not None: + refs = [self.ref_tfm(ref) for ref in refs] + + mcaps = self.tokenizer.encode_batch( + refs, + **self.encode_kwargs, + ) + item["mult_captions"] = mcaps + + if self.add_raw_refs: + item["mult_references"] = refs + + else: + raise ValueError( + f"Invalid argument {idxs=} with {self.ref_selection=}." + ) + + return item + + +class OnlineEncodeCaptionsTransformWithEmbs: + def __init__( + self, + audio_tfm: Optional[Callable], + ref_selection: Union[str, int, slice], + add_raw_refs: bool, + tokenizer: AACTokenizer, + encode_kwargs: dict[str, Any], + mrefs_src_key: str = "captions", + mrefs_embs_src_key: str = "captions_embs", + ) -> None: + super().__init__() + self.audio_tfm = audio_tfm + self.ref_selection = ref_selection + self.add_raw_refs = add_raw_refs + self.tokenizer = tokenizer + self.encode_kwargs = encode_kwargs + self.mrefs_src_key = mrefs_src_key + self.mrefs_embs_src_key = mrefs_embs_src_key + + def __call__(self, item: dict[str, Any]) -> dict[str, Any]: + references = item[self.mrefs_src_key] + references_embs = item[self.mrefs_embs_src_key] + + if self.audio_tfm is not None: + item["audio"] = self.audio_tfm(item["audio"]) + + if isinstance(self.ref_selection, str): + if self.ref_selection == "random": + idxs = random.randint(0, len(references) - 1) + else: + raise ValueError(f"Invalid argument {self.ref_selection=}.") + else: + idxs = self.ref_selection + + if isinstance(idxs, int): + reference = references[idxs] + caption = self.tokenizer.encode_single( + reference, + **self.encode_kwargs, + ) + + item["captions"] = caption + item[self.mrefs_embs_src_key] = references_embs[idxs] + if self.add_raw_refs: + item["references"] = reference + + elif idxs == slice(None): + item.pop("captions") + mult_captions = self.tokenizer.encode_batch( + references, + **self.encode_kwargs, + ) + + item["mult_captions"] = mult_captions + item[f"mult_{self.mrefs_embs_src_key}"] = references_embs[idxs] + if self.add_raw_refs: + item["mult_references"] = references + + else: + raise ValueError(f"Invalid argument {idxs=} with {self.ref_selection=}.") + + return item + + +def split_indexes( + indexes: Iterable[int], + ratios: Iterable[float], +) -> list[list[int]]: + assert 0 <= sum(ratios) <= 1.0 + 1e-20, f"Found {sum(ratios)=} not in [0, 1]." 
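+    # Ratios are taken as contiguous slices of the given indexes, in order:
+    # e.g. ratios=[0.8, 0.2] yields the first ~80% and the remaining ~20% of the indexes.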
+ indexes = list(indexes) + ratio_cumsum = 0.0 + outs = [] + for ratio in ratios: + start = math.floor(ratio_cumsum * len(indexes)) + end = math.floor((ratio_cumsum + ratio) * len(indexes)) + sub_indexes = indexes[start:end] + outs.append(sub_indexes) + ratio_cumsum += ratio + return outs + + +def generate_random_split( + size: int, ratios: Iterable[float], seed: Union[int, None, Generator] +) -> list[list[int]]: + if isinstance(seed, int): + generator = Generator().manual_seed(seed) + else: + generator = seed + + indexes = torch.randperm(size, generator=generator).tolist() + splitted_indexes = split_indexes(indexes, ratios) + return splitted_indexes + + +def get_auto_num_cpus() -> int: + return len(os.sched_getaffinity(0)) + + +def auto_n_workers(n_workers: Optional[int]) -> int: + if n_workers is None: + num_cpus = get_auto_num_cpus() + return num_cpus + else: + return n_workers + + +def build_mult_task_train_dataset( + train_dset: SizedDatasetLike, + train_tokenizer: AACTokenizer, + root: str, + audio_padding: str, + task_hdfs: list[str], + task_tag_types: list[str], + idx_to_name_dicts: list[dict[int, str]], + task_data_add: str, +) -> SizedDatasetLike: + def tfm_task_0(item: dict) -> dict: + item["task"] = torch.as_tensor(0) + return item + + def get_tfm_task_1( + task_tag_type: str, + idx_to_name: dict[int, str], + ) -> Callable: + assert task_tag_type in ("audioset", "fsd50k") + + def tfm_task_1(item: dict) -> dict: + tags = item["tags"].tolist() + + indexes = torch.randperm(len(tags)) + tags = [tags[idx] for idx in indexes] + joined_tags_names = ", ".join(idx_to_name[tag] for tag in tags) + encoded_tags = train_tokenizer.encode_single( + joined_tags_names, default=False + ) + item["captions"] = encoded_tags + item["task"] = torch.as_tensor(1) + + return item + + return tfm_task_1 + + train_dset_task_0 = TransformWrapper(train_dset, tfm_task_0) + + train_dsets = [train_dset_task_0] + keep_padding = ("audio",) if audio_padding in ("crop", "longest") else () + + for task_hdf, task_tag_type, idx_to_name in zip(task_hdfs, task_tag_types, idx_to_name_dicts): # type: ignore + hdf_fpath = osp.join(root, "HDF", task_hdf) # type: ignore + hdf_dset = HDFDataset( + hdf_fpath, + get_tfm_task_1(task_tag_type, idx_to_name), + keep_padding=keep_padding, + ) + train_dsets.append(hdf_dset) # type: ignore + + if task_data_add == "cat": + train_dset = ConcatDataset(train_dsets) + elif task_data_add == "cat_lim": + # train_dsets = [train_dset_task_0] + [ + # WrapperSampler(dset, len(train_dset_task_0)) for dset in train_dsets[1:] + # ] + raise NotImplementedError + elif task_data_add == "zip_max": + train_dset = ZipDataset(*train_dsets, mode="max") + elif task_data_add == "zip_min": + train_dset = ZipDataset(*train_dsets, mode="min") + else: + raise ValueError(f"Invalid argument {task_data_add=}.") + + return train_dset + + +def get_counter(sents: list[list[str]], nmax: int) -> Counter[tuple[str, ...]]: + assert nmax > 0 + counter = Counter() + for n in range(1, nmax + 1): + for sent in sents: + for ngram in ngrams(sent, n): + assert len(ngram) == n + counter[ngram] += 1 + return counter + + +def build_caps_complexity_scores( + sents: list[list[str]], nmax: int, mode: str = "ign_ngram_sup" +) -> Tensor: + assert mode in ("ign_ngram_sup", "zero_ngram_sup") + counter = dict(get_counter(sents, nmax)) + count_per_ngram = [ + sum(count for ngram, count in counter.items() if len(ngram) == n) + for n in range(1, nmax + 1) + ] + scores = torch.stack( + [ + torch.as_tensor( + [ + ( + ( + torch.as_tensor( + [ + 
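+                                    # inverse n-gram frequency: rare n-grams increase the complexity score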
count_per_ngram[n - 1] / counter[ngram] + for ngram in ngrams(sent, n) + ] + ).mean() + ) + if mode != "zero_ngram_sup" or len(sent) >= n + else 0.0 + ) + for n in range(1, nmax + 1) + if mode != "ign_ngram_sup" or len(sent) >= n + ] + ).mean() + for sent in sents + ] + ) + max_score = scores.max().item() + scores = scores / max_score + scores = scores.numpy() + return scores + + +def _get_unscaled_score( + tok_sent: list[str], + counter: dict[tuple[str, ...], int], + count_per_ngram: list[int], + nmax: int, + mode: str, +) -> Tensor: + ngram_scores = [] + for n in range(1, nmax + 1): + if len(tok_sent) < n: + if mode == "ign_ngram_sup": + continue + elif mode == "zero_ngram_sup": + ngram_scores.append(0.0) + else: + raise ValueError(f"Invalid argument {mode=}.") + else: + ngram_score = torch.as_tensor( + [ + count_per_ngram[n - 1] / counter[ngram] + for ngram in ngrams(tok_sent, n) + ] + ).mean() + ngram_scores.append(ngram_score) + score = torch.as_tensor(ngram_scores).mean() + return score + + +class CapsComplexity: + def __init__(self, nmax: int, mode: str = "ign_ngram_sup") -> None: + super().__init__() + self._nmax = nmax + self._mode = mode + + self._counter = {} + self._count_per_ngram = [] + self._max_score = 1.0 + + def fit(self, tok_sents: list[list[str]]) -> Tensor: + counter = dict(get_counter(tok_sents, self._nmax)) + count_per_ngram = [ + sum(count for ngram, count in counter.items() if len(ngram) == n) + for n in range(1, self._nmax + 1) + ] + + unscaled_scores = [ + _get_unscaled_score(sent, counter, count_per_ngram, self._nmax, self._mode) + for sent in tok_sents + ] + unscaled_scores = torch.as_tensor(unscaled_scores) + max_score = unscaled_scores.max().item() + + self._counter = counter + self._count_per_ngram = count_per_ngram + self._max_score = max_score + + return self.get_scores(tok_sents) + + def get_scores(self, tok_sents: list[list[str]]) -> Tensor: + scores = torch.as_tensor([self.get_score(tok_sent) for tok_sent in tok_sents]) + return scores + + def get_score(self, tok_sent: list[str]) -> Tensor: + score = _get_unscaled_score( + tok_sent, self._counter, self._count_per_ngram, self._nmax, self._mode + ) + score = score / self._max_score + return score diff --git a/src/conette/datamodules/hdf.py b/src/conette/datamodules/hdf.py new file mode 100644 index 000000000..219555ed1 --- /dev/null +++ b/src/conette/datamodules/hdf.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os.path as osp + +from typing import Iterable, Optional, Union + +import tqdm + +from torch import nn +from torch.utils.data.dataloader import DataLoader + +from conette.datamodules.aac_dm import AACDataModule +from conette.datamodules.collate import AdvancedCollateDict +from conette.datamodules.common import ( + OnlineEncodeCaptionsTransform, + get_auto_num_cpus, +) +from conette.datasets.hdf import HDFDataset +from conette.datasets.utils import ( + AACConcat, + AACDuplicate, + AACSelectColumnsWrapper, + TransformWrapper, + WrapperSampler, +) +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.utils.csum import csum_any + + +pylog = logging.getLogger(__name__) + +DEFAULT_TRAIN_COLS = ("audio", "audio_shape", "captions") +DEFAULT_VAL_COLS = ("audio", "audio_shape", "captions") +DEFAULT_TEST_COLS = ( + "audio", + "audio_shape", + "captions", + "dataset", + "subset", + "fname", + "index", +) + + +class HDFDataModule(AACDataModule): + TUNE_MODE = False + _IGNORE_ARGS = ( + "train_audio_tfm", + "val_audio_tfm", + "test_audio_tfm", 
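+        # (transform and tokenizer objects are excluded from save_hyperparameters(), see AACDataModule)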
+ "train_tokenizer", + ) + AUDIO_PADDINGS = ("batch", "longest", "crop") + + def __init__( + self, + # AACDataModule params + root: str = "data", + bsize: int = 512, + n_workers: Optional[int] = 0, + pin_memory: bool = True, + train_drop_last: bool = False, + verbose: int = 1, + train_cols: Iterable[str] = DEFAULT_TRAIN_COLS, + val_cols: Iterable[str] = DEFAULT_VAL_COLS, + test_cols: Iterable[str] = DEFAULT_TEST_COLS, + train_audio_tfm: Optional[nn.Module] = None, + val_audio_tfm: Optional[nn.Module] = None, + test_audio_tfm: Optional[nn.Module] = None, + train_tokenizer: Optional[AACTokenizer] = None, + # Other params + train_hdfs: Union[str, Iterable[str]] = (), + val_hdfs: Union[str, Iterable[str]] = (), + test_hdfs: Union[str, Iterable[str]] = (), + predict_hdfs: Union[str, Iterable[str]] = (), + audio_padding: str = "batch", + main_hdf_duplicate: Optional[str] = None, + main_hdf_min: Optional[str] = None, + main_hdf_balanced: Optional[Iterable[str]] = None, + n_added_data: Optional[int] = None, + ) -> None: + """Initialize the AudioCaps datamodule for building dataloaders. + + :param root: The dataset parent directory. defaults to "data". + :param bsize: The batch size of the dataloaders. defaults to 512. + :param n_workers: The number of workers of the dataloaders. defaults to 0. + :param pin_memory: If True, the dataloaders will pin memory of tensors. defaults to True. + :param verbose: Verbose level. defaults to 1. + :param train_cols: The columns to extract from the original HDF dataset source during training. + :param val_cols: The columns to extract from the original HDF dataset source during validation. + :param test_cols: The columns to extract from the original HDF dataset source during testing. + :param train_audio_tfm: The train audio transform to apply to each item. defaults to None. + :param val_audio_tfm: The val audio transform to apply to each item. defaults to None. + :param test_audio_tfm: The test audio transform to apply to each item. defaults to None. + :param train_tokenizer: The AACTokenizer for train captions. None will create a default AACTokenizer. defaults to None. + :param train_hdfs: List of HDF filenames for training. defaults to (). + :param val_hdfs: List of HDF filenames for validation. defaults to (). + :param test_hdfs: List of HDF filenames for testing. defaults to (). + :param predict_hdfs: List of HDF filenames for prediction. defaults to (). + :param audio_padding: Audio batch padding mode. Can be one of ("batch", "crop", "longest"). defaults to "batch". + :param main_hdf_duplicate: Duplicate the main train dataset to have the same length than the sum of other datasets added. defaults to None. + :param main_hdf_min: Reduce other added per epoch to have the same length than the main train dataset. defaults to None. 
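+        :param main_hdf_balanced: Train HDF names to balance: each of these datasets and the concatenation of the remaining train HDFs are duplicated or re-sampled each epoch to a common size. defaults to None.
+        :param n_added_data: Number of examples drawn per epoch from the added/balanced datasets when main_hdf_min or main_hdf_balanced is used. defaults to None.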
+ """ + # Process args + root = osp.expanduser(osp.expandvars(root)) + + if n_workers is None: + n_workers = get_auto_num_cpus() + if verbose >= 1: + pylog.info(f"Found {n_workers} CPU that will be used for DataLoaders.") + + train_cols = list(train_cols) + val_cols = list(val_cols) + test_cols = list(test_cols) + + def process_hdfs_args(hdfs: Union[str, Iterable[str]]) -> list[str]: + if isinstance(hdfs, str): + return [hdfs] + else: + return list(hdfs) + + train_hdfs = process_hdfs_args(train_hdfs) + val_hdfs = process_hdfs_args(val_hdfs) + test_hdfs = process_hdfs_args(test_hdfs) + predict_hdfs = process_hdfs_args(predict_hdfs) + + if train_tokenizer is None: + train_tokenizer = AACTokenizer() + + # Check args + if main_hdf_duplicate is not None and main_hdf_min is not None: + raise ValueError( + f"Cannot use arguments {main_hdf_duplicate=} and {main_hdf_min=} at the same time." + ) + if main_hdf_duplicate is not None and main_hdf_duplicate not in train_hdfs: + raise ValueError( + f"Invalid argument {main_hdf_duplicate=}. (expected one of train hdf files {train_hdfs})" + ) + + if main_hdf_min is not None and main_hdf_min not in train_hdfs: + raise ValueError( + f"Invalid argument {main_hdf_min=}. (expected one of train hdf files {train_hdfs})" + ) + + if audio_padding not in self.AUDIO_PADDINGS: + raise ValueError( + f"Invalid argument {audio_padding=}. (expected one of {self.AUDIO_PADDINGS})" + ) + + if ( + main_hdf_min is None + and main_hdf_balanced is None + and n_added_data is not None + ): + raise ValueError( + f"Invalid argument {n_added_data=} with {main_hdf_min=} and {main_hdf_balanced=}." + ) + + if main_hdf_balanced is not None and not all( + hdf_name in train_hdfs for hdf_name in main_hdf_balanced + ): + raise ValueError(f"Invalid argument {main_hdf_balanced=}.") + + super().__init__( + root=root, + bsize=bsize, + n_workers=n_workers, + pin_memory=pin_memory, + train_drop_last=train_drop_last, + verbose=verbose, + train_cols=train_cols, + val_cols=val_cols, + test_cols=test_cols, + ) + self._train_audio_tfm = train_audio_tfm + self._val_audio_tfm = val_audio_tfm + self._test_audio_tfm = test_audio_tfm + self._train_tokenizer = train_tokenizer + + self._wrapper_samplers: list[WrapperSampler] = [] + + def train_dataloader(self) -> DataLoader: + for wrapper_sampler in self._wrapper_samplers: + prev_csum = csum_any(wrapper_sampler.indexes) + wrapper_sampler.reset_indexes() + if self.hp.verbose >= 2: + csum = csum_any(wrapper_sampler.indexes) + pylog.debug(f"Indexes has been shuffled. ({prev_csum}, {csum})") + return super().train_dataloader() + + # Other methods + def _setup_fit(self) -> None: + keep_padding = ( + ("audio",) if self.hp.audio_padding in ("crop", "longest") else () + ) + train_dsets_lst = [ + HDFDataset( + osp.join(self.hp.root, "HDF", fname), + keep_padding=keep_padding, + ) + for fname in self.hp.train_hdfs + ] + val_dsets_lst = [ + HDFDataset( + osp.join(self.hp.root, "HDF", fname), + keep_padding=keep_padding, + ) + for fname in self.hp.val_hdfs + ] + if self.hp.verbose >= 2: + pylog.debug( + f"HDF datasets loaded. 
(train={len(train_dsets_lst)}, val={len(val_dsets_lst)})" + ) + + train_dsets_lst = [ + AACSelectColumnsWrapper(dset, include=self.hp.train_cols) + for dset in train_dsets_lst + ] + val_dsets_lst = [ + AACSelectColumnsWrapper(dset, include=self.hp.val_cols) + for dset in val_dsets_lst + ] + + train_mrefs: list[list[str]] = [ + refs + for train_dset_i in tqdm.tqdm( + train_dsets_lst, + disable=self.hp.verbose < 1, + desc="Loading captions for build id-to-token mappings...", + ) + for refs in train_dset_i.at(None, "captions") + ] + + if self.hp.main_hdf_duplicate is not None: + tgt_idx = self.hp.train_hdfs.index(self.hp.main_hdf_duplicate) + tgt_dset = train_dsets_lst[tgt_idx] + other_sum = sum( + len(dset) for i, dset in enumerate(train_dsets_lst) if i != tgt_idx + ) + + if self.hp.verbose >= 1: + pylog.info( + f"Duplicate dataset {self.hp.main_hdf_duplicate} from {len(tgt_dset)} to {other_sum}." + ) + + if len(tgt_dset) < other_sum: + train_dsets_lst[tgt_idx] = AACDuplicate(tgt_dset, other_sum) # type: ignore + + elif self.hp.main_hdf_min is not None: + tgt_idx = self.hp.train_hdfs.index(self.hp.main_hdf_min) + tgt_dset = train_dsets_lst[tgt_idx] + other_dsets = [ + dset for i, dset in enumerate(train_dsets_lst) if i != tgt_idx + ] + other_dsets = AACConcat(*other_dsets) + + if self.hp.n_added_data is not None: + n_added_data = self.hp.n_added_data + else: + n_added_data = len(tgt_dset) + + if self.hp.verbose >= 1: + pylog.info( + f"Minimize others datasets from {len(other_dsets)} to {n_added_data}." + ) + + other_dsets = WrapperSampler(other_dsets, n_added_data) + self._wrapper_samplers = [other_dsets] + train_dsets_lst = [tgt_dset, other_dsets] + + elif self.hp.main_hdf_balanced is not None: + train_hdf_fnames: list[str] = list(self.hp.train_hdfs) + main_hdf_balanced: list[str] = list(self.hp.main_hdf_balanced) + + tgt_idxs = [ + train_hdf_fnames.index(hdf_name) for hdf_name in main_hdf_balanced + ] + tgt_dsets = [train_dsets_lst[tgt_idx] for tgt_idx in tgt_idxs] + other_dsets = [ + dset for i, dset in enumerate(train_dsets_lst) if i not in tgt_idxs + ] + other_dsets = AACConcat(*other_dsets) + + max_ds_size = max(map(len, tgt_dsets + [other_dsets])) + + train_dsets_lst = [] + wrapper_samplers = [] + + if self.hp.n_added_data is not None: + n_added_data = self.hp.n_added_data + else: + n_added_data = max_ds_size + del max_ds_size + + if self.hp.verbose >= 1: + pylog.info( + f"Minimize others datasets from {len(other_dsets)} to {n_added_data}." 
+ ) + + for tgt_ds in tgt_dsets + [other_dsets]: + if len(tgt_ds) == n_added_data: + train_dsets_lst.append(tgt_ds) + elif len(tgt_ds) < n_added_data: + train_dsets_lst.append(AACDuplicate(tgt_ds, n_added_data)) + else: # > + wrapped = WrapperSampler(tgt_ds, n_added_data) + train_dsets_lst.append(wrapped) + wrapper_samplers.append(wrapped) + + self._wrapper_samplers = wrapper_samplers + + else: + if self.hp.verbose >= 1: + pylog.info("No change applied to added datasets.") + + if len(train_dsets_lst) == 1: + train_dset = train_dsets_lst[0] + else: + train_dset = AACConcat(*train_dsets_lst) + + if len(val_dsets_lst) == 1: + val_dset = val_dsets_lst[0] + else: + val_dset = AACConcat(*val_dsets_lst) + + del train_dsets_lst, val_dsets_lst + + if not self._train_tokenizer.is_fit(): + train_mrefs_flat = [ref for refs in train_mrefs for ref in refs] + self._train_tokenizer.fit(train_mrefs_flat) + + train_tfm = OnlineEncodeCaptionsTransform( + self._train_audio_tfm, + "random", + False, + self._train_tokenizer, + dict(add_bos_eos=True, default=None, padding=None), + ) + val_tfm = OnlineEncodeCaptionsTransform( + self._val_audio_tfm, + slice(None), + True, + self._train_tokenizer, + dict( + add_bos_eos=True, + default=self._train_tokenizer.unk_token, + padding="batch", + ), + ) + + train_dset = TransformWrapper(train_dset, train_tfm) + val_dset = TransformWrapper(val_dset, val_tfm) + + self._train_dset = train_dset + self._val_dset = val_dset + + pad_values = { + "captions": self._train_tokenizer.pad_token_id, + "mult_captions": self._train_tokenizer.pad_token_id, + } + if self.hp.audio_padding == "batch": + pad_values["audio"] = 0.0 # type: ignore + + crop_keys = ("audio",) if self.hp.audio_padding == "crop" else () + self._train_collate = AdvancedCollateDict(pad_values, crop_keys) + self._val_collate = AdvancedCollateDict(pad_values, crop_keys) + + if self.hp.verbose >= 1: + vocab_size = self._train_tokenizer.get_vocab_size() + pylog.info(f"Train dataset size: {len(train_dset)}") + pylog.info(f"Validation dataset size: {len(val_dset)}") + pylog.info(f"Vocabulary size: {vocab_size}") + + def _setup_test(self) -> None: + keep_padding = ( + ("audio",) if self.hp.audio_padding in ("crop", "longest") else () + ) + dsets = { + fname: HDFDataset( + osp.join(self.hp.root, "HDF", fname), + keep_padding=keep_padding, + ) + for fname in self.hp.test_hdfs + } + test_tfm = OnlineEncodeCaptionsTransform( + self._test_audio_tfm, + slice(None), + True, + self._train_tokenizer, + dict( + add_bos_eos=True, + default=self._train_tokenizer.unk_token, + padding="batch", + ), + ) + + dsets = { + fname: AACSelectColumnsWrapper(dset, include=self.hp.test_cols) + for fname, dset in dsets.items() + } + dsets = { + fname: TransformWrapper(dset, test_tfm) for fname, dset in dsets.items() + } + + self._test_dsets = dsets + + pad_values = { + "captions": self._train_tokenizer.pad_token_id, + "mult_captions": self._train_tokenizer.pad_token_id, + } + if self.hp.audio_padding == "batch": + pad_values["audio"] = 0.0 # type: ignore + + crop_keys = ("audio",) if self.hp.audio_padding == "crop" else () + self._test_collate = AdvancedCollateDict(pad_values, crop_keys) + + def _setup_predict(self) -> None: + keep_padding = ( + ("audio",) if self.hp.audio_padding in ("crop", "longest") else () + ) + dsets = { + fname: HDFDataset( + osp.join(self.hp.root, "HDF", fname), + keep_padding=keep_padding, + ) + for fname in self.hp.predict_hdfs + } + test_tfm = OnlineEncodeCaptionsTransform( + self._test_audio_tfm, + slice(None), + True, + 
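+            # positional args: ref_selection=slice(None), add_raw_refs=True;
+            # captions are not encoded here since mrefs_src_key is None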
self._train_tokenizer, + dict(), + mrefs_src_key=None, + ) + + dsets = { + fname: AACSelectColumnsWrapper( + dset, include=self.hp.test_cols, exclude=("captions",) + ) + for fname, dset in dsets.items() + } + dsets = { + fname: TransformWrapper(dset, test_tfm) for fname, dset in dsets.items() + } + + self._predict_dsets = dsets + + pad_values = {} + if self.hp.audio_padding == "batch": + pad_values["audio"] = 0.0 # type: ignore + + crop_keys = ("audio",) if self.hp.audio_padding == "crop" else () + self._predict_collate = AdvancedCollateDict(pad_values, crop_keys) diff --git a/src/conette/datasets/__init__.py b/src/conette/datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/conette/datasets/hdf/__init__.py b/src/conette/datasets/hdf/__init__.py new file mode 100644 index 000000000..a4f830976 --- /dev/null +++ b/src/conette/datasets/hdf/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from .common import SHAPE_SUFFIX +from .dataset import HDFDataset +from .pack import pack_to_hdf diff --git a/src/conette/datasets/hdf/common.py b/src/conette/datasets/hdf/common.py new file mode 100644 index 000000000..65e17f925 --- /dev/null +++ b/src/conette/datasets/hdf/common.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Sequence + +import h5py +import torch + +from torch import Tensor + + +# Force this encoding +HDF_ENCODING = "utf-8" +# Type for strings +HDF_STRING_DTYPE = h5py.string_dtype(HDF_ENCODING, None) +# Type for empty lists +HDF_VOID_DTYPE = h5py.opaque_dtype("V1") +# Key suffix to store tensor shapes (because they are padded in hdf file) +SHAPE_SUFFIX = "_shape" + + +def all_eq(seq: Sequence[Any]) -> bool: + """Returns True if all element in list are the same.""" + if len(seq) == 0: + return True + else: + first = seq[0] + return all(first == elt for elt in seq[1:]) + + +def get_inverse_perm(indexes: Tensor, dim: int = -1) -> Tensor: + """Return inverse permutation indexes. + + :param indexes: Original permutation indexes as tensor of shape (..., N). + :param dim: Dimension of indexes. defaults to -1. + :returns: Inverse permutation indexes of shape (..., N). + """ + arange = torch.arange( + indexes.shape[dim], + dtype=indexes.dtype, + device=indexes.device, + ) + arange = arange.expand(*indexes.shape) + indexes_inv = torch.empty_like(indexes) + indexes_inv = indexes_inv.scatter(dim, indexes, arange) + return indexes_inv diff --git a/src/conette/datasets/hdf/dataset.py b/src/conette/datasets/hdf/dataset.py new file mode 100644 index 000000000..912a80d95 --- /dev/null +++ b/src/conette/datasets/hdf/dataset.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os +import os.path as osp + +from typing import Any, Callable, Iterable, Optional, Union, overload + +import h5py +import numpy as np +import pickle +import torch +import yaml + +from h5py import Dataset as HDFRawDataset +from torch import Tensor +from torch.utils.data.dataset import Dataset + +from .common import ( + HDF_ENCODING, + HDF_STRING_DTYPE, + HDF_VOID_DTYPE, + SHAPE_SUFFIX, + all_eq, + get_inverse_perm, +) + + +pylog = logging.getLogger(__name__) + + +class HDFDataset(Dataset): + # Initialization + def __init__( + self, + hdf_fpath: str, + transform: Optional[Callable] = None, + keep_padding: Iterable[str] = (), + open_hdf: bool = True, + ) -> None: + """ + :param hdf_fpath: The path to the HDF file. + :param transforms: The transform to apply values (Tensor). default to None. 
+ :param keep_padding: Keys to keep padding values. defaults to (). + :param open_hdf: If True, open the HDF file at start. defaults to True. + """ + if not osp.isfile(hdf_fpath): + names = os.listdir(osp.dirname(hdf_fpath)) + names = [name for name in names if name.endswith(".hdf")] + names = list(sorted(names)) + raise FileNotFoundError( + f"Cannot find HDF file in path {hdf_fpath=}. Possible HDF files are:\n{yaml.dump(names, sort_keys=False)}" + ) + keep_padding = list(keep_padding) + + super().__init__() + self._hdf_fpath = hdf_fpath + self._transform = transform + self._keep_padding = keep_padding + + self._hdf_file: Any = None + + if open_hdf: + self.open() + + # Properties + @property + def column_names(self) -> list[str]: + """The name of each column of the dataset.""" + return list(self.get_hdf_keys()) + + @property + def shape(self) -> tuple[int, ...]: + """The shape of the Clotho dataset.""" + return len(self), len(self.column_names) + + @property + def info(self) -> dict[str, Any]: + """Return the global dataset info.""" + return eval(self._hdf_file.attrs.get("info", "{}")) + + # Public methods + @overload + def at(self, idx: int) -> dict[str, Any]: + ... + + @overload + def at(self, idx: Union[Iterable[int], slice, None]) -> dict[str, list]: + ... + + @overload + def at(self, idx: Any, column: Any) -> Any: + ... + + def at( + self, + idx: Union[int, Iterable[int], slice, None] = None, + column: Union[str, Iterable[str], None] = None, + raw: bool = False, + ) -> Any: + if not self.is_open(): + raise RuntimeError( + f"Cannot get_raw value with closed HDF file. ({self._hdf_file is not None=} and {bool(self._hdf_file)=})" + ) + + if idx is None: + idx = slice(None) + elif isinstance(idx, Tensor): + idx = idx.tolist() + if column is None: + column = self.column_names + + if not isinstance(column, str) and isinstance(column, Iterable): + return {column_i: self.at(idx, column_i) for column_i in column} + + if column not in self.column_names: + raise ValueError( + f"Invalid argument {column=}. (expected one of {tuple(self.column_names)})" + ) + + if isinstance(idx, slice): + is_mult = True + elif isinstance(idx, Iterable): + if not all(isinstance(idx_i, int) for idx_i in idx): + raise TypeError(f"Invalid argument {idx=}.") + is_mult = True + elif isinstance(idx, int): + if not (-len(self) <= idx < len(self)): + raise IndexError( + f"Invalid argument {idx=}. 
(expected int in range [{-len(self)}, {len(self)-1}])" + ) + is_mult = False + else: + raise TypeError(f"Invalid argument type {type(idx)=}.") + + hdf_value = self._raw_at(idx, column) + if raw: + return hdf_value + + if is_mult: + hdf_values = hdf_value + else: + hdf_values = [hdf_value] + del hdf_value + + shape_name = f"{column}{SHAPE_SUFFIX}" + must_remove_padding = ( + shape_name in self._hdf_file.keys() and column not in self._keep_padding + ) + hdf_ds: HDFRawDataset = self._hdf_file[column] + hdf_dtype = hdf_ds.dtype + + if must_remove_padding: + shapes = self._raw_at(idx, shape_name) + if not is_mult: + shapes = [shapes] + slices_lst = [ + tuple(slice(shape_i) for shape_i in shape) for shape in shapes + ] + else: + slices_lst = [None] * int(hdf_ds.shape[0]) + + outputs = [] + + for hdf_value, slices in zip(hdf_values, slices_lst): + # Remove the padding part + if slices is not None: + hdf_value = hdf_value[slices] + + # Decode all bytes to string + if hdf_dtype == HDF_STRING_DTYPE: + hdf_value = _decode_rec(hdf_value, HDF_ENCODING) + # Convert numpy.array to torch.Tensor + elif isinstance(hdf_value, np.ndarray): + if hdf_dtype != HDF_VOID_DTYPE: + hdf_value = torch.from_numpy(hdf_value) + else: + hdf_value = hdf_value.tolist() + # Convert numpy scalars + elif np.isscalar(hdf_value) and hasattr(hdf_value, "item"): + hdf_value = hdf_value.item() # type: ignore + + outputs.append(hdf_value) + + if not is_mult: + outputs = outputs[0] + return outputs + + def close(self) -> None: + if not self.is_open(): + raise RuntimeError("Cannot close the HDF file twice.") + self._hdf_file.close() + self._hdf_file = None + + def get_attrs(self) -> dict[str, Any]: + return self._hdf_file.attrs + + def get_hdf_fpath(self) -> str: + return self._hdf_fpath + + def get_hdf_keys(self) -> tuple[str, ...]: + if self.is_open(): + return tuple(self._hdf_file.keys()) + else: + raise RuntimeError("Cannot get keys from a closed HDF file.") + + def get_column_shape(self, column_name: str) -> tuple[int, ...]: + if not self.is_open(): + raise RuntimeError( + f"Cannot get max_shape with a closed HDF file. ({self._hdf_file is not None=} and {bool(self._hdf_file)=})" + ) + return tuple(self._hdf_file[column_name].shape) + + def is_open(self) -> bool: + return self._hdf_file is not None and bool(self._hdf_file) + + def open(self) -> None: + if self.is_open(): + raise RuntimeError("Cannot open the HDF file twice.") + self._hdf_file = h5py.File(self._hdf_fpath, "r") + self._sanity_check() + + # Magic methods + def __eq__(self, __o: object) -> bool: + return isinstance(__o, HDFDataset) and pickle.dumps(self) == pickle.dumps(__o) + + def __exit__(self) -> None: + if self.is_open(): + self.close() + + @overload + def __getitem__(self, idx: int) -> dict[str, Any]: + ... + + @overload + def __getitem__(self, idx: Union[Iterable[int], slice, None]) -> dict[str, list]: + ... + + @overload + def __getitem__(self, idx: Any) -> Any: + ... 
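+    # Illustrative indexing semantics (a sketch of the behaviour implemented below,
+    # assuming the HDF file is open and `dset = HDFDataset("some.hdf")` is a hypothetical instance):
+    #   dset[0]                   -> dict of all columns for item 0 (transform applied, if any)
+    #   dset[0, "audio"]          -> the "audio" value of item 0 (no transform)
+    #   dset[:10, "captions"]     -> list with the "captions" of the first 10 items
+    #   dset[[2, 5], ("audio", "captions")] -> dict of lists for the given indexes
+    #   dset[None]                -> every item, as a dict of per-column lists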
+ + def __getitem__( + self, + idx: Union[int, Iterable[int], None, slice, tuple[Any, Any]], + ) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + + item = self.at(idx, column) # type: ignore + if isinstance(idx, int) and column is None and self._transform is not None: + item = self._transform(item) + return item + + def __getstate__(self) -> dict[str, Any]: + return { + "hdf_fpath": self._hdf_fpath, + "transform": self._transform, + "keep_padding": self._keep_padding, + } + + def __hash__(self) -> int: + hash_value = 0 + if self.is_open(): + hash_value += self._hdf_file.attrs["global_hash_value"] + if self._transform is not None: + hash_value += hash(self._transform) + hash_value += sum(map(hash, self._keep_padding)) + return hash_value + + def __len__(self) -> int: + return self._hdf_file.attrs["length"] + + def __repr__(self) -> str: + return ( + f"HDFDataset(size={len(self)}, hdf_fname={osp.basename(self._hdf_fpath)})" + ) + + def __setstate__(self, data: dict[str, Any]) -> None: + is_init = hasattr(self, "_hdf_fpath") and hasattr(self, "_hdf_file") + files_are_different = is_init and self._hdf_fpath != data["hdf_fpath"] + is_open = is_init and self.is_open() + + if is_init and files_are_different and is_open: + self.close() + + self._hdf_fpath = data["hdf_fpath"] + self._transform = data["transform"] + self._keep_padding = data["keep_padding"] + self._hdf_file = None + + if not is_init or (files_are_different and is_open): + self.open() + + # Private methods + def _sanity_check(self) -> None: + lens = [dset.shape[0] for dset in self._hdf_file.values()] + if not all_eq(lens) or lens[0] != len(self): + pylog.error( + f"Incorrect length stored in HDF file. (found {lens=} and {len(self)=})" + ) + + def _raw_at(self, idx: Union[int, Iterable[int], slice], column: str) -> Any: + if isinstance(idx, Iterable): + sorted_idxs, local_idxs = torch.as_tensor(idx).sort(dim=-1) + sorted_idxs = sorted_idxs.numpy() + hdf_value: Any = self._hdf_file[column][sorted_idxs] + inv_local_idxs = get_inverse_perm(local_idxs) + hdf_value = [hdf_value[local_idx] for local_idx in inv_local_idxs] + else: + hdf_value: Any = self._hdf_file[column][idx] + return hdf_value + + +def _decode_rec(value: Union[bytes, list], encoding: str) -> Union[str, list]: + """Decode bytes to str with the specified encoding. Works recursively on list of str, list of list of str, etc.""" + if isinstance(value, bytes): + return value.decode(encoding=encoding) + elif isinstance(value, Iterable): + return [_decode_rec(elt, encoding) for elt in value] + else: + raise TypeError( + f"Invalid argument type {type(value)}. 
(expected bytes or Iterable)" + ) diff --git a/src/conette/datasets/hdf/pack.py b/src/conette/datasets/hdf/pack.py new file mode 100644 index 000000000..b13afc169 --- /dev/null +++ b/src/conette/datasets/hdf/pack.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import datetime +import json +import logging +import os +import os.path as osp +import zlib + +from typing import Any, Callable, Mapping, Optional, Sized, Union + +import h5py +import numpy as np +import torch +import tqdm + +from torch import nn, Tensor +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset + +from .common import ( + HDF_ENCODING, + HDF_STRING_DTYPE, + HDF_VOID_DTYPE, + SHAPE_SUFFIX, + all_eq, +) +from .dataset import HDFDataset + + +pylog = logging.getLogger(__name__) + + +class Compose: + def __init__(self, *fns: Callable) -> None: + super().__init__() + self.fns = fns + + def __call__(self, x: Any) -> Any: + for fn in self.fns: + x = fn(x) + return x + + +def _checksum_rec( + value: Any, +) -> int: + if isinstance(value, bytes): + return zlib.adler32(value) + elif isinstance(value, (np.ndarray, Tensor)): + return int(value.sum().item()) + elif isinstance(value, (int, float)): + return int(value) + elif isinstance(value, str): + return _checksum_rec(value.encode()) + elif isinstance(value, (list, tuple)): + return sum(map(_checksum_rec, value)) + else: + raise TypeError(f"Invalid argument type {value.__class__.__name__}.") + + +def _flat_subdicts(dic: dict[str, Any]) -> dict[str, Any]: + out = {} + for k, v in dic.items(): + if isinstance(v, dict): + for kv, vv in v.items(): + if kv == "": + out[k] = vv + else: + out[f"{k}_{kv}"] = vv + else: + out[k] = v + return out + + +def _get_shape_and_dtype( + value: Union[int, float, str, Tensor, list] +) -> tuple[tuple[int, ...], str]: + """Returns the shape and the hdf_dtype for an input.""" + if isinstance(value, int): + shape = () + hdf_dtype = "i" + + elif isinstance(value, float): + shape = () + hdf_dtype = "f" + + elif isinstance(value, str): + shape = () + hdf_dtype = HDF_STRING_DTYPE + + elif isinstance(value, Tensor): + shape = tuple(value.shape) + if value.is_floating_point(): + hdf_dtype = "f" + else: + hdf_dtype = "i" + + elif isinstance(value, (list, tuple)): + if len(value) == 0: + shape = (0,) + hdf_dtype = HDF_VOID_DTYPE + else: + sub_shapes_and_dtypes = list(map(_get_shape_and_dtype, value)) + sub_shapes = [shape for shape, _ in sub_shapes_and_dtypes] + sub_dtypes = [dtype for _, dtype in sub_shapes_and_dtypes] + sub_dims = list(map(len, sub_shapes)) + + if not all_eq(sub_dims): + raise TypeError( + f"Unsupported list of heterogeneous shapes lengths. (found {sub_dims=})" + ) + if not all_eq(sub_dtypes): + raise TypeError( + f"Unsupported list of heterogeneous types. (found {sub_dtypes=})" + ) + # Check for avoid ragged array like [["a", "b"], ["c"], ["d", "e"]] + if not all_eq(sub_shapes): + raise TypeError( + f"Unsupported list of heterogeneous shapes. (found {sub_shapes=} for {value=})" + ) + + max_subshape = tuple( + max(shape[i] for shape in sub_shapes) for i in range(len(sub_shapes[0])) + ) + shape = (len(value),) + max_subshape + hdf_dtype = sub_dtypes[0] + else: + raise TypeError( + f"Unsupported type {value.__class__.__name__} in function get_shape_and_dtype." 
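+            # Illustrative expected outputs for the branches above (hypothetical calls):
+            #   _get_shape_and_dtype(5)                 -> ((), "i")
+            #   _get_shape_and_dtype(2.0)               -> ((), "f")
+            #   _get_shape_and_dtype("a caption")       -> ((), HDF_STRING_DTYPE)
+            #   _get_shape_and_dtype(torch.zeros(2, 3)) -> ((2, 3), "f")
+            #   _get_shape_and_dtype(["a", "bcd"])      -> ((2,), HDF_STRING_DTYPE)
+            #   _get_shape_and_dtype([])                -> ((0,), HDF_VOID_DTYPE)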
+        )
+
+    return shape, hdf_dtype
+
+
+@torch.inference_mode()
+def pack_to_hdf(
+    dataset: Any,
+    hdf_fpath: str,
+    pre_save_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
+    overwrite: bool = False,
+    metadata: str = "",
+    verbose: int = 0,
+    loader_bsize: int = 8,
+    loader_n_workers: Optional[int] = None,
+) -> "HDFDataset":
+    """
+    Pack a dataset into an HDF file.
+
+    :param dataset: The dataset to pack. Must be sized and all of its items must be dicts.
+        The keys of each item dict are strings and the values can be int, float, str, Tensor, non-empty list[int], non-empty list[float], non-empty list[str].
+        If values are tensors or lists, the number of dimensions must be the same for all items in the dataset.
+    :param hdf_fpath: The path to the HDF file.
+    :param pre_save_transform: The optional transform applied to each item returned by the dataset BEFORE storing it in the HDF file.
+        Can be used for deterministic transforms like Resample, LogMelSpectrogram, etc. defaults to None.
+    :param overwrite: If True, the file hdf_fpath can be overwritten. defaults to False.
+    :param metadata: Additional metadata string to add to the HDF file. defaults to ''.
+    :param verbose: Verbose level. defaults to 0.
+    :param loader_bsize: The batch size of the dataloader. defaults to 8.
+    :param loader_n_workers: The number of workers of the dataloader. If None, it will be set to `len(os.sched_getaffinity(0))`. defaults to None.
+    :returns: The created HDFDataset object, NOT opened.
+    """
+    # Check inputs
+    if not isinstance(dataset, Dataset):
+        raise TypeError(
+            f"Cannot pack a non-dataset '{dataset.__class__.__name__}'. (found {isinstance(dataset, Dataset)=})"
+        )
+    if not isinstance(dataset, Sized):
+        raise TypeError(
+            f"Cannot pack a non-sized dataset '{dataset.__class__.__name__}'. (found {isinstance(dataset, Sized)=})"
+        )
+    if osp.exists(hdf_fpath) and not osp.isfile(hdf_fpath):
+        raise RuntimeError(f"Item {hdf_fpath=} exists but it is not a file.")
+    if not overwrite and osp.isfile(hdf_fpath):
+        raise ValueError(
+            f"Cannot overwrite file {hdf_fpath}. Please remove it or use the overwrite=True option."
+        )
+
+    if pre_save_transform is None:
+        pre_save_transform = _flat_subdicts
+    else:
+        pre_save_transform = Compose(pre_save_transform, _flat_subdicts)
+
+    if loader_n_workers is None:
+        loader_n_workers = len(os.sched_getaffinity(0))
+        if verbose >= 2:
+            pylog.debug(f"loader_n_workers is None, setting it to {loader_n_workers}.")
+
+    if verbose >= 2:
+        pylog.debug(f"Start packing data into HDF file '{hdf_fpath}'...")
+
+    # Step 1: Init max_shapes and hdf_dtypes with the first item
+    item_0 = dataset[0]
+    if not isinstance(item_0, dict):
+        raise ValueError(
+            f"Invalid item type for {dataset.__class__.__name__}. 
(expected dict but found {type(item_0)})" + ) + + shapes_0 = {} + hdf_dtypes_0 = {} + item_0 = pre_save_transform(item_0) + + for attr_name, value in item_0.items(): + shape, hdf_dtype = _get_shape_and_dtype(value) + shapes_0[attr_name] = shape + hdf_dtypes_0[attr_name] = hdf_dtype + + max_shapes: dict[str, tuple[int, ...]] = shapes_0 + hdf_dtypes: dict[str, str] = hdf_dtypes_0 + + loader = DataLoader( + dataset, + batch_size=loader_bsize, + shuffle=False, + num_workers=loader_n_workers, + collate_fn=nn.Identity(), + drop_last=False, + pin_memory=False, + ) + + for batch in tqdm.tqdm( + loader, + desc="Pre compute shapes...", + disable=verbose <= 0, + ): + batch = [pre_save_transform(item) for item in batch] + for item in batch: + for attr_name, value in item.items(): + shape, hdf_dtype = _get_shape_and_dtype(value) + max_shapes[attr_name] = tuple( + map(max, zip(max_shapes[attr_name], shape)) + ) + if hdf_dtypes[attr_name] == hdf_dtype or hdf_dtype == HDF_VOID_DTYPE: + # Note: HDF_VOID_DTYPE is compatible + pass + elif hdf_dtypes[attr_name] == HDF_VOID_DTYPE: + # Note: if the element 0 was void dtype, override with more specific dtype + hdf_dtypes[attr_name] = hdf_dtype + else: + raise ValueError( + f"Found different hdf_dtype. (with {hdf_dtypes[attr_name]=} != {hdf_dtype=} and {attr_name=} with {value=})" + ) + + if verbose >= 2: + pylog.debug(f"Found max_shapes:\n{max_shapes}") + pylog.debug(f"Found hdf_dtypes:\n{hdf_dtypes}") + + now = datetime.datetime.now() + creation_date = now.strftime("%Y-%m-%d_%H-%M-%S") + + if hasattr(dataset, "info") and isinstance(dataset.info, Mapping): # type: ignore + info = dict(dataset.info) # type: ignore + else: + info = {} + + with h5py.File(hdf_fpath, "w") as hdf_file: + # Step 2: Build hdf datasets in file + hdf_dsets = {} + for attr_name, shape in max_shapes.items(): + hdf_dtype = hdf_dtypes.get(attr_name) + + kwargs: dict[str, Any] = {} + if hdf_dtype == "i": + kwargs["fillvalue"] = 0 + elif hdf_dtype == "f": + kwargs["fillvalue"] = 0.0 + elif hdf_dtype in (HDF_STRING_DTYPE, HDF_VOID_DTYPE): + pass + else: + raise ValueError( + f"Unknown value {hdf_dtype=}. (with {attr_name=} and {attr_name in hdf_dtypes=})" + ) + + if verbose >= 2: + pylog.debug( + f"Build hdf dset '{attr_name}' with shape={(len(dataset),) + shape}." + ) + + hdf_dsets[attr_name] = hdf_file.create_dataset( + attr_name, + (len(dataset),) + shape, + hdf_dtype, + **kwargs, + ) + + if len(shape) > 0: + shape_name = f"{attr_name}{SHAPE_SUFFIX}" + hdf_dsets[shape_name] = hdf_file.create_dataset( + shape_name, (len(dataset), len(shape)), "i", fillvalue=-1 + ) + + # Fill hdf datasets with a second pass through the whole dataset + i = 0 + global_hash_value = 0 + + loader = DataLoader( + dataset, + batch_size=loader_bsize, + shuffle=False, + num_workers=loader_n_workers, + collate_fn=nn.Identity(), + drop_last=False, + pin_memory=False, + ) + + for batch in tqdm.tqdm( + loader, + desc="Pack data into HDF...", + disable=verbose <= 0, + ): + batch = [pre_save_transform(item) for item in batch] + + for item in batch: + for attr_name, value in item.items(): + hdf_dset = hdf_dsets[attr_name] + shape, hdf_dtype = _get_shape_and_dtype(value) + + # Check every shape + if len(shape) != hdf_dset.ndim - 1: + raise ValueError( + f"Invalid number of dimension in audio (expected {len(shape)}, found {len(shape)})." 
+ ) + + # Resize dataset if needed + if any( + shape_i > dset_shape_i + for shape_i, dset_shape_i in zip(shape, hdf_dset.shape[1:]) + ): + pylog.error( + f"Resize hdf_dset {attr_name} of shape {tuple(hdf_dset.shape[1:])} with new {shape=}." + ) + raise RuntimeError( + "INTERNAL ERROR: Cannot resize dataset when pre-computing shapes." + ) + + if isinstance(value, Tensor) and value.is_cuda: + value = value.cpu() + + # If the value is a sequence but not an array or tensor + if hdf_dtype in ("i", "f") and not isinstance( + value, (Tensor, np.ndarray) + ): + value = np.array(value) + + # Note: "dset_audios[slices]" is a generic version of "dset_audios[i, :shape_0, :shape_1]" + slices = (i,) + tuple(slice(shape_i) for shape_i in shape) + try: + hdf_dset[slices] = value + except TypeError as err: + pylog.error( + f"Cannot set data {value} into {hdf_dset[slices].shape} ({attr_name=}, {i=}, {slices=})" + ) + raise err + + # Store original shape if needed + shape_name = f"{attr_name}{SHAPE_SUFFIX}" + if shape_name in hdf_dsets.keys(): + hdf_shapes_dset = hdf_dsets[shape_name] + hdf_shapes_dset[i] = shape + + global_hash_value += _checksum_rec(value) + + i += 1 + + # note: HDF cannot save too large int values with too many bits + global_hash_value = global_hash_value % (2**31) + + attributes = { + "creation_date": creation_date, + "source_dataset": dataset.__class__.__name__, + "length": len(dataset), + "metadata": str(metadata), + "author": "Etienne Labbé (Labbeti)", + "author_mail": "labbeti.pub@gmail.com", + "encoding": HDF_ENCODING, + "info": str(info), + "global_hash_value": global_hash_value, + } + if verbose >= 2: + dumped_attributes = json.dumps(attributes, indent="\t") + pylog.debug(f"Saving attributes in HDF file:\n{dumped_attributes}") + + for attr_name, attr_val in attributes.items(): + try: + hdf_file.attrs[attr_name] = attr_val + except TypeError as err: + pylog.error( + f"Cannot store attribute {attr_name=} with value {attr_val=} in HDF." + ) + raise err + + if verbose >= 2: + pylog.debug(f"Data into has been packed into HDF file '{hdf_fpath}'.") + + hdf_dataset = HDFDataset(hdf_fpath, open_hdf=False) + return hdf_dataset diff --git a/src/conette/datasets/typing.py b/src/conette/datasets/typing.py new file mode 100644 index 000000000..7ba70a9b1 --- /dev/null +++ b/src/conette/datasets/typing.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import ( + Any, + Protocol, + runtime_checkable, +) + + +@runtime_checkable +class DatasetLike(Protocol): + def __getitem__(self, idx: int) -> Any: + raise NotImplementedError("Protocal abstract method.") + + +@runtime_checkable +class SizedDatasetLike(Protocol): + def __getitem__(self, idx: int) -> Any: + raise NotImplementedError("Protocal abstract method.") + + def __len__(self) -> int: + raise NotImplementedError("Protocal abstract method.") + + +@runtime_checkable +class AACDatasetLike(Protocol): + """Protocal abstract class for aac datasets. Used only for typing. 
+ + Methods signatures: + - column_names: () -> list[str] + - at: (int, str) -> Any + - __getitem__: (int, str) -> Any + - __len__: () -> int + """ + + @property + def column_names(self) -> list[str]: + raise NotImplementedError("Protocal abstract method.") + + def at(self, idx: Any, column: Any) -> Any: + raise NotImplementedError("Protocal abstract method.") + + def __getitem__(self, idx: Any) -> Any: + raise NotImplementedError("Protocal abstract method.") + + def __len__(self) -> int: + raise NotImplementedError("Protocal abstract method.") diff --git a/src/conette/datasets/utils.py b/src/conette/datasets/utils.py new file mode 100644 index 000000000..5e10bda1b --- /dev/null +++ b/src/conette/datasets/utils.py @@ -0,0 +1,1111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import math +import os.path as osp +import time + +from functools import cache +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + MutableMapping, + Optional, + Sequence, + Sized, + TypeVar, + Union, +) + +import torch +import torchaudio +import tqdm + +from torch import Tensor +from torch.utils.data.dataset import Dataset +from torchaudio.backend.common import AudioMetaData + +from conette.datasets.typing import AACDatasetLike, SizedDatasetLike +from conette.utils.disk_cache import disk_cache +from conette.utils.log_utils import warn_once +from conette.utils.misc import pass_filter + + +pylog = logging.getLogger(__name__) +T = TypeVar("T") + + +def _process_idx( + idx: Union[int, Iterable[int], Tensor, slice, None], +) -> Union[int, list[int], slice]: + if isinstance(idx, (int, slice)): + return idx + elif idx is None: + return slice(None) + elif isinstance(idx, Tensor): + if idx.dtype == torch.bool: + idx = torch.where(idx)[0] + elif idx.is_floating_point(): + raise ValueError( + f"Invalid argument dtype {idx.dtype=}. (expected int or bool)" + ) + return idx.tolist() + elif isinstance(idx, Iterable): + return list(idx) + else: + raise TypeError(f"Invalid argument type {type(idx)=}.") + + +class EmptyDataset(Generic[T], Dataset[T]): + def __getitem__(self, *args, **kwargs) -> None: + raise NotImplementedError( + f"Invalid call of getitem for {self.__class__.__name__}." + ) + + def __len__(self) -> int: + return 0 + + +class LambdaDataset(Generic[T], Dataset[T]): + def __init__( + self, + fn: Callable[[int], Any], + length: int, + fn_kws: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__() + self._fn = fn + self._length = length + self._fn_kws = fn_kws if fn_kws is not None else {} + + def __getitem__(self, *args, **kwargs) -> Any: + return self._fn(*args, **kwargs, **self._fn_kws) + + def __len__(self) -> int: + return self._length + + +class Wrapper(Generic[T], Dataset[T]): + """ + Base class for dataset wrappers. + + :param source: The source dataset to wrap. + """ + + def __init__(self, source: T) -> None: + super().__init__() + self._source = source + + # Properties + @property + def source(self) -> T: + return self._source + + # Public methods + def unwrap(self, recursive: bool = True) -> Any: + if not recursive: + return self._source + else: + dset = self._source + while isinstance(dset, Wrapper): + dset = dset.unwrap() + return dset + + # Magic methods + def __getitem__(self, idx: int) -> Any: + return self._source.__getitem__(idx) # type: ignore + + def __len__(self) -> int: + if isinstance(self._source, Sized): + return len(self._source) + else: + raise NotImplementedError( + f"Wrapped dataset {self._source.__class__.__name__} is not Sized." 
+ ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self._source)})" + + +class AACSubset(Wrapper[AACDatasetLike]): + """Similar to torch.utils.data.Subset but for AACDataset classes.""" + + def __init__( + self, + dataset: AACDatasetLike, + indexes: Iterable[int], + overwrite_index: bool = False, + ) -> None: + super().__init__(dataset) + self._indexes = list(indexes) + self._overwrite_index = overwrite_index + + # Public properties + @property + def column_names(self) -> list[str]: + return self._source.column_names + + @property + def indexes(self) -> list[int]: + return self._indexes + + # Public methods + def at( + self, + idx: Union[int, slice, Iterable[int], Tensor, None], + column: Union[str, Iterable[str], None], + ) -> Any: + if idx is None: + idx = slice(None) + if isinstance(idx, Tensor): + idx = idx.tolist() + + if isinstance(idx, Iterable): + idx = list(idx) + local_idx = [self._indexes[idx_i] for idx_i in idx] + else: # int or slice + local_idx = self._indexes[idx] + del idx + item = self._source.at(local_idx, column) + + if not self._overwrite_index: + return item + + if isinstance(item, dict): + index = item.get("index", None) + if isinstance(index, int) or ( + isinstance(index, Iterable) + and all(isinstance(index_i, int) for index_i in index) + ): + item["index"] = local_idx + + elif column == "index": + item = local_idx + + return item + + # Magic methods + def __getitem__(self, idx: Any) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + def __len__(self) -> int: + return len(self._indexes) + + +class AACConcat(Wrapper[tuple[AACDatasetLike, ...]]): + """Similar to torch.utils.data.ConcatDataset but for AACDataset classes.""" + + def __init__(self, *datasets: AACDatasetLike) -> None: + super().__init__(datasets) + cumsum = [] + prev_size = 0 + for dset in datasets: + dset_size = len(dset) + cumsum.append(dset_size + prev_size) + prev_size += dset_size + assert len(cumsum) == 0 or cumsum[-1] == len( + self + ), f"Found {cumsum[-1]=} != {len(self)=}." + + self._cumsum = cumsum + + @property + def column_names(self) -> list[str]: + column_names_lst = [dset.column_names for dset in self._source] + column_names = intersect_lists(column_names_lst) + return column_names + + @cache + def _index_to_dset_and_local_indexes(self, idx: int) -> tuple[int, int]: + if not isinstance(idx, int) or idx < 0 or idx >= self._cumsum[-1]: + raise IndexError(f"Invalid index {idx} for {self.__class__.__name__}.") + + local_index = None + dset_idx = None + prevsum = 0 + for i, sum_ in enumerate(self._cumsum): + if idx < sum_: + dset_idx = i + local_index = idx - prevsum + break + prevsum = sum_ + + if local_index is None or dset_idx is None: + raise IndexError( + f"Invalid index {idx} for {self.__class__.__name__}. 
(found {local_index=} and {dset_idx=})" + ) + + return dset_idx, local_index + + def at( + self, + idx: Union[int, Iterable[int], Tensor, slice, None], + column: Union[str, Iterable[str], None] = None, + ) -> Any: + if idx is None: + idx = slice(None) + if isinstance(idx, slice): + idx = range(len(self))[idx] + if isinstance(idx, Tensor): + idx = idx.tolist() + if column is None: + column = self.column_names + if not isinstance(column, str) and isinstance(column, Iterable): + return {column_i: self.at(idx, column_i) for column_i in column} + + assert isinstance(column, str) + + if isinstance(idx, Iterable): + item = [] + for idx_i in idx: + dset_idx, local_index = self._index_to_dset_and_local_indexes(idx_i) + item_i = self._source[dset_idx].at(local_index, column) + item.append(item_i) + return item + + elif isinstance(idx, int): + dset_idx, local_index = self._index_to_dset_and_local_indexes(idx) + return self._source[dset_idx].at(local_index, column) + + else: + raise TypeError(f"Invalid argument type {idx=}.") + + def __getitem__( + self, + idx: Any, + ) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) # type: ignore + + def __len__(self) -> int: + return sum(map(len, self._source)) + + +class TransformWrapper(Wrapper): + def __init__( + self, + dataset: SizedDatasetLike, + transforms: Union[Callable, Iterable[Callable], None], + index: Union[None, int, str] = None, + default_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + """Wrap a dataset method `getitem` with a transform.""" + if transforms is None: + transforms = [] + elif isinstance(transforms, Callable): + transforms = [transforms] + else: + transforms = list(transforms) + + if default_kwargs is None: + default_kwargs = {} + super().__init__(dataset) + self._transforms = transforms + self._index = index + self._default_kwargs = default_kwargs + + def apply_transform(self, item: Any) -> Any: + for tfm in self._transforms: + item = tfm(item, **self._default_kwargs) + return item + + def __getitem__(self, idx: Any) -> Any: + item = self._source.__getitem__(idx) + if self._index is None: + return self.apply_transform(item) + + elif isinstance(item, MutableMapping): + item[self._index] = self.apply_transform( + item[self._index], **self._default_kwargs + ) + return item + + elif isinstance(item, Iterable): + return tuple( + ( + self.apply_transform(sub_item, **self._default_kwargs) + if i == self._index + else sub_item + ) + for i, sub_item in enumerate(item) + ) + + else: + raise TypeError( + f"Invalid item type {type(item)}. (expected dict or iterable)" + ) + + +class CacheWrap(Wrapper): + def __init__(self, dataset: Any) -> None: + super().__init__(dataset) + + @cache + def __getitem__(self, idx: int) -> tuple: + return self._source.__getitem__(idx) + + @cache + def __len__(self) -> int: + return len(self._source) + + def load_items( + self, verbose: bool = False, desc: str = "Loading dataset..." 
+ ) -> None: + for i in tqdm.trange(len(self), disable=not verbose, desc=desc): + self[i] + + +class DatasetCycle(Wrapper): + def __init__(self, dataset: SizedDatasetLike, target_size: int) -> None: + assert isinstance(dataset, Sized) + assert len(dataset) <= target_size + super().__init__(dataset) + self._target_size = target_size + + def __getitem__(self, idx: int) -> Any: + local_index = idx % len(self._source) + return self._source[local_index] + + def __len__(self) -> int: + return self._target_size + + +class WrapperSampler(Wrapper[AACDatasetLike]): + """Randomly sample each element from source.""" + + def __init__( + self, + source: AACDatasetLike, + size: int, + generator: Union[int, torch.Generator, None] = None, + ) -> None: + assert len(source) >= size + if isinstance(generator, int): + generator = torch.Generator().manual_seed(generator) + super().__init__(source) + self.size = size + self.generator = generator + self.indexes = torch.arange(size) + self.reset_indexes() + + def reset_indexes(self) -> None: + self.indexes = torch.randperm(len(self.source), generator=self.generator)[ + : self.size + ] + + @property + def column_names(self) -> list[str]: + return self._source.column_names + + def at(self, idx: Any, column: Any) -> Any: + assert isinstance( + idx, int + ), f"WrapperSampler does not support non-integer indexes. (found {idx=})" + idx = self.indexes[idx] + return self._source.at(idx, column) + + def __getitem__(self, idx: Any) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + def __len__(self) -> int: + return self.size + + +class Duplicate(Wrapper[SizedDatasetLike]): + def __init__(self, source: SizedDatasetLike, maxsize: int) -> None: + super().__init__(source) + self.maxsize = maxsize + + def __getitem__(self, idx: int) -> Any: + idx = idx % len(self._source) + return super().__getitem__(idx) + + def __len__(self) -> int: + return self.maxsize + + +class AACDuplicate(Wrapper[AACDatasetLike]): + def __init__(self, source: AACDatasetLike, maxsize: int) -> None: + super().__init__(source) + self.maxsize = maxsize + + @property + def column_names(self) -> list[str]: + return self._source.column_names + + def at(self, idx: Union[int, Iterable[int], None], column: Any = None) -> Any: + idx = self._map_index(idx) + return self._source.at(idx, column) + + def __getitem__(self, idx: Any) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + def __len__(self) -> int: + return self.maxsize + + def _map_index(self, idx: Union[int, Iterable[int], None]) -> Any: + if isinstance(idx, int): + idx = idx % len(self._source) + elif isinstance(idx, Iterable) and all(isinstance(idx_i, int) for idx_i in idx): + idx = [self._map_index(idx_i) for idx_i in idx] + elif idx is None: + idx = self._map_index(range(len(self))) + else: + raise TypeError(f"Invalid argument type {idx=}.") + return idx + + +class DsetTestSample(Dataset): + def __init__(self) -> None: + super().__init__() + self._all_captions = [ + ( + "Cars travel past in the distance as a clock ticks", + "A clock is ticking with traffic in the background", + "An old clock with a pendulum that is swinging back and forth is ticking.", + "An old clock with a pendulum is ticking.", + "The industrial time card clock emits a thick, 
hallow ticking.", + ), + ( + "Chicks are chirping when a rooster is crowing.", + "Chicks are chirping while a rooster is crowing.", + "Seagulls squawk, then hens and chicks chirp and a rooster crows thrice as waves break against the shore.", + "Waves breaking on a shore and seagulls squawking followed by hens and chicks chirping and a rooster crowing three times", + "Many varieties of bird sing their songs, including a crowing cock.", + ), + ( + "A liquid is completely squeezed out of a tube.", + "A liquid is squeezed out of a tube until it is empty.", + "An air bladder being emptied into a jelly like material.", + "Something is being squeezed out of a bottle with difficulty.", + "The last of the liquid soap is being squeezed out of the bottle.", + ), + ] + + @property + def column_names(self) -> list[str]: + return ["audio", "captions", "index"] + + def at(self, idx: Union[int, slice, Iterable[int]], column: str) -> Any: + if isinstance(idx, slice): + idx = range(len(self))[idx] + if isinstance(idx, Iterable): + return [self.at(i, column) for i in idx] + + if column == "audio": + return torch.full((3,), idx) + elif column == "captions": + return self._all_captions[idx] + else: + raise ValueError(f"Invalid index {idx=}.") + + def __getitem__(self, idx: int) -> dict[str, Any]: + return { + "index": idx, + "audio": self.at(idx, "audio"), + "captions": self.at(idx, "captions"), + } + + def __len__(self) -> int: + return len(self._all_captions) + + +class ZipDataset(Dataset): + def __init__( + self, + *datasets: SizedDatasetLike, + transform: Optional[Callable] = None, + mode: str = "equal", + ) -> None: + if len(datasets) == 0: + raise ValueError(f"Cannot zip without datasets. (found {len(datasets)=})") + if any(len(dset) == 0 for dset in datasets): + raise ValueError( + f"Cannot zip an empty dataset. (found sizes {tuple(len(dset) for dset in datasets)})" + ) + + if mode == "equal": + lens = list(map(len, datasets)) + if any(lens[0] != len_ for len_ in lens): + raise ValueError( + f"Invalid datasets lengths for ZipDatasets. (found {lens=})" + ) + + length = len(datasets[0]) + + elif mode == "min": + length = min(map(len, datasets)) + + elif mode == "max": + length = max(map(len, datasets)) + + else: + MODES = ("equal", "min", "max") + raise ValueError(f"Invalid argument {mode=}. 
(expected one of {MODES})") + + super().__init__() + self._datasets = datasets + self._transform = transform + self._mode = mode + self._length = length + + def __getitem__(self, idx: int) -> dict[str, Any]: + item = {} + for dset in self._datasets: + item |= dset[idx % len(dset)] + if self._transform is not None: + item = self._transform(item) + return item + + def __len__(self) -> int: + return self._length + + +class AACDatasetFromRaw(AACDatasetLike, Generic[T]): + def __init__(self, all_items: Mapping[str, Iterable[T]]) -> None: + all_items = {k: list(v) for k, v in all_items.items()} + super().__init__() + self._all_items: dict[str, list[T]] = all_items # type: ignore + + @classmethod + def from_iter(cls, all_items: Iterable[Mapping[str, T]]) -> "AACDatasetFromRaw[T]": + all_items = list(all_items) + if len(all_items) == 0: + col_names = {} + else: + col_names = set(all_items[0].keys()) + + if not all(set(col.keys()) == col_names for col in all_items): + raise ValueError("Invalid column names keys.") + + all_items = {k: [col[k] for col in all_items] for k in col_names} # type: ignore + return AACDatasetFromRaw(all_items) # type: ignore + + @property + def column_names(self) -> list[str]: + return list(self._all_items.keys()) + + def at(self, idx: Any, column: Any) -> Any: + if idx is None: + idx = slice(None) + if column is None: + column = self.column_names + + if not isinstance(column, str) and isinstance(column, Iterable): + return {column_i: self.at(idx, column_i) for column_i in column} + + if isinstance(idx, (int, slice)) and column in self._all_items.keys(): + return self._all_items[column][idx] + + if isinstance(idx, slice): + idx = range(len(self))[idx] + + if isinstance(idx, Iterable): + idx = list(idx) + if not all(isinstance(idx_i, int) for idx_i in idx): + raise TypeError( + f"Invalid input type for idx={idx}. (expected Iterable[int], not Iterable[{idx.__class__.__name__}])" + ) + return [self.at(idx_i, column) for idx_i in idx] + + def __getitem__( + self, + idx: Any, + ) -> dict[str, Any]: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + + item = self.at(idx, column) + return item + + def __len__(self) -> int: + if len(self._all_items) > 0: + return len(next(iter(self._all_items.values()))) + else: + return 0 + + +def filter_audio_sizes( + dset: AACDatasetLike, + min_audio_size: float = 0.0, + max_audio_size: float = math.inf, + cache_path: Optional[str] = osp.join("~", ".cache"), + verbose: int = 0, + previous_indexes: Optional[Iterable[int]] = None, + use_duration_column: bool = False, +) -> list[int]: + if verbose >= 2: + len_ = len(dset if previous_indexes is None else list(previous_indexes)) + pylog.debug( + f"Loading durations from {len_} audio files... 
(with {cache_path=} and {use_duration_column=})" + ) + + if use_duration_column and "duration" in dset.column_names: + durations = dset.at(previous_indexes, "duration") + else: + fpaths = dset.at(previous_indexes, "fpath") + if cache_path is not None: + infos = disk_cache(load_audio_metadata, fpaths, cache_path=cache_path) + else: + infos = load_audio_metadata(fpaths) + durations = [(info.num_frames / info.sample_rate) for info in infos.values()] + + indexes = [ + i + for i, duration in enumerate(durations) + if min_audio_size <= duration <= max_audio_size + ] + n_excluded = len(dset) - len(indexes) + if verbose >= 1: + pylog.info( + f"Exclude {n_excluded}/{len(dset)} files with audio size not in [{min_audio_size}, {max_audio_size}] seconds." + ) + if verbose >= 2: + lim = 10 + excluded_indexes = list(sorted(set(range(len(dset))).difference(indexes))) + pylog.debug(f"Show first {lim} excluded indexes: {excluded_indexes[:lim]}.") + return indexes + + +def load_audio_metadata( + fpaths: list[str], +) -> dict[str, AudioMetaData]: + infos = { + fpath: torchaudio.info(fpath) # type: ignore + for fpath in tqdm.tqdm( + fpaths, + desc=f"Loading audio metadata from {len(fpaths)} files...", + disable=pylog.level >= logging.INFO, + ) + } + return infos + + +def intersect_lists(lst_of_lst: list[list[T]]) -> list[T]: + if len(lst_of_lst) <= 0: + return [] + out = lst_of_lst[0] + for lst_i in lst_of_lst[1:]: + out = [name for name in out if name in lst_i] + if len(out) == 0: + break + return out + + +class Cacher(Wrapper): + def __init__( + self, + source: AACDatasetLike, + cache_keys: Optional[dict[str, str]] = None, + ) -> None: + super().__init__(source) + + if cache_keys is None: + cache_keys = {} + if not all(v in ("get_raw", "get") for v in cache_keys.values()): + raise ValueError + + null_value = "NULL" + + self._cache_keys = cache_keys + self._caches = { + k: [null_value for _ in range(len(source))] for k in cache_keys.keys() + } + self._null_value = null_value + + def at(self, idx: int, column: str) -> Any: + return self._get_cache(idx, column, "get_field") + + def _get_cache(self, idx: int, column: str, type_: str) -> Any: + if self._cache_keys.get(column) == type_: + cached_value = self._caches[column][idx] + if cached_value != self._null_value: + return cached_value + else: + value = self._source.at(idx, column) + self._caches[column][idx] = value + return value + else: + value = self._source.at(idx, column) + return value + + +class DatasetList(Dataset[T]): + def __init__( + self, + items: Iterable[T], + transform: Optional[Callable[[T], Any]] = None, + ) -> None: + if not isinstance(items, Sequence): + items = list(items) + + super().__init__() + self._items = items + self._transform = transform + + def __getitem__(self, idx: int) -> Any: + item = self._items[idx] + if self._transform is not None: + item = self._transform(item) + return item + + def __len__(self) -> int: + return len(self._items) + + +class DebugTracker(Wrapper): + def __init__(self, source: SizedDatasetLike) -> None: + super().__init__(source) + self._delta_sum = 0.0 + self._delta_count = 0 + + def __getitem__(self, idx: int) -> Any: + start = time.perf_counter() + item = super().__getitem__(idx) + end = time.perf_counter() + self._delta_sum += end - start + self._delta_count += 1 + return item + + def get_average(self) -> float: + if self._delta_count == 0: + return 0.0 + else: + return self._delta_sum / self._delta_count + + +class AACSelectColumnsWrapper(Wrapper[AACDatasetLike]): + """Wrapper to filter columns in 
AACDatasetLike. + + ```python + >>> dset = ... + >>> dset = AACSelectColumnsWrapper(dset, include=("captions",)) + >>> dset[0] + ... {"captions": ...} + ``` + """ + + DEFAULT_VAL = None + + def __init__( + self, + source: AACDatasetLike, + /, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + use_default: bool = True, + ) -> None: + if include is None: + not_found = [] + else: + not_found = [name for name in include if name not in source.column_names] + + if len(not_found) > 0: + warn_once( + f"Cannot find {len(not_found)} column(s) {not_found} in {source} dataset. (found only {source.column_names})", + pylog, + ) + + column_names = [ + name for name in source.column_names if pass_filter(name, include, exclude) + ] + if use_default: + column_names += not_found + super().__init__(source) + self._column_names = column_names + self._use_default = use_default + self._not_found = not_found + + @property + def column_names(self) -> list[str]: + return self._column_names + + def at(self, idx: Any, column: Union[str, Iterable[str], None]) -> Any: + if column is None: + column = self.column_names + + if isinstance(column, str): + if column not in self.column_names: + raise ValueError( + f"Invalid argument {column=}. (expected one of {tuple(self.column_names)})" + ) + if self._use_default and column in self._not_found: + if isinstance(idx, Tensor): + if idx.dtype == torch.bool: + idx = torch.where(idx) + elif idx.is_floating_point(): + raise ValueError( + f"Invalid argument {idx=}. (expected int or bool tensor)" + ) + idx = idx.tolist() + + if isinstance(idx, int): + return self.DEFAULT_VAL + elif isinstance(idx, Iterable): + idx = list(idx) + return [self.DEFAULT_VAL] * len(idx) + else: + raise NotImplementedError( + f"Calling method at() with {self._use_default=} on a column with {type(idx)=} is currently not supported. (with {column=} and {self=})" + ) + else: + return self._source.at(idx, column) + + elif isinstance(column, Iterable): + if not all(column_i in self.column_names for column_i in column): + raise ValueError( + f"Invalid argument {column=}. (expected one of {tuple(self.column_names)})" + ) + + return {column_i: self.at(idx, column_i) for column_i in column} + else: + raise TypeError( + f"Invalid argument type {column=}. 
(expected str or Iterable[str])" + ) + + def __getitem__(self, idx: Any) -> dict[str, Any]: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + item = self.at(idx, column) + return item + + +class AACReplaceColumnWrapper(Wrapper[AACDatasetLike]): + def __init__( + self, source: AACDatasetLike, target_column: str, values: Iterable[Any] + ) -> None: + if hasattr(source, "_transform"): + raise ValueError + values = list(values) + super().__init__(source) + self._target_column = target_column + self._values = values + + @property + def column_names(self) -> list[str]: + return self.source.column_names + + def at( + self, + idx: Union[int, Iterable[int], slice, None], + column: Union[str, Iterable[str], None], + ) -> Any: + if idx is None: + idx = slice(None) + if column is None: + column = self.column_names + + if not isinstance(column, str) and isinstance(column, Iterable): + return {column_i: self.at(idx, column_i) for column_i in column} + + if isinstance(idx, (int, slice)) and column == self._target_column: + return self._values[idx] + + if isinstance(idx, slice): + idx = range(len(self))[idx] + + if isinstance(idx, Iterable): + idx = list(idx) + if not all(isinstance(idx_i, int) for idx_i in idx): + raise TypeError( + f"Invalid input type for idx={idx}. (expected Iterable[int], not Iterable[{idx.__class__.__name__}])" + ) + return [self.at(idx_i, column) for idx_i in idx] + + assert column != self._target_column + return self.source.at(idx, column) + + def __getitem__(self, idx: int) -> dict[str, Any]: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + +class PostSelectColumnsWrapper(Wrapper[SizedDatasetLike]): + def __init__( + self, + source: SizedDatasetLike, + column_names: Iterable[str], + /, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + ) -> None: + column_names = [ + name for name in column_names if pass_filter(name, include, exclude) + ] + super().__init__(source) + self._column_names = column_names + + @property + def column_names(self) -> list[str]: + return self._column_names + + def __getitem__(self, idx: int) -> dict[str, Any]: + item = self._source[idx] + item = {column: item[column] for column in self._column_names} + return item + + +class AACTransformWrapper(Wrapper[AACDatasetLike]): + def __init__( + self, + source: AACDatasetLike, + transforms: dict[str, Callable[[Any, int], Any]], + verbose: int = 0, + ) -> None: + super().__init__(source) + self._transforms = transforms + self._verbose = verbose + + @property + def column_names(self) -> list[str]: + return self._source.column_names + + def at(self, idx: Any, column: Any) -> Any: + if idx is None: + idx = slice(None) + elif isinstance(idx, Tensor): + idx = idx.tolist() + if column is None: + column = self.column_names + + if not isinstance(column, str) and isinstance(column, Iterable): + return {column_i: self.at(idx, column_i) for column_i in column} + assert isinstance(column, str) + + transform = self._transforms.get(column) + if transform is None: + return self._source.at(idx, column) + + if isinstance(idx, slice): + idx = range(len(self))[idx] + + if isinstance(idx, Iterable): + idx = list(idx) + if not all(isinstance(idx_i, int) for idx_i in idx): + raise TypeError( + f"Invalid input type for idx={idx}. 
(expected Iterable[int], not Iterable[{idx.__class__.__name__}])" + ) + + values = self._source.at(idx, column) + values = [transform(value, idx_i) for value, idx_i in zip(values, idx)] + return values + + elif isinstance(idx, int): + value = self._source.at(idx, column) + return transform(value, idx) + + else: + raise TypeError(f"Invalid argument type {type(idx)=}.") + + def __getitem__(self, idx: Any) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + def __len__(self) -> int: + return len(self._source) + + +class DummyAACDataset(AACDatasetLike): + def __init__(self, size: int = 10) -> None: + super().__init__() + self.size = size + + @property + def column_names(self) -> list[str]: + return ["index", "value"] + + def at(self, idx, column) -> Any: + if idx is None: + idx = slice(None) + if isinstance(idx, slice): + idx = range(len(self))[idx] + + if column is None: + column = self.column_names + if not isinstance(column, str) and isinstance(column, Iterable): + return {col_i: self.at(idx, col_i) for col_i in column} + + if isinstance(idx, Iterable): + return [self.at(idx_i, column) for idx_i in idx] + + if column == "value": + return f"value_{idx}" + elif column == "index": + return idx + else: + raise ValueError(f"Invalid argument {column=}.") + + def __getitem__(self, idx: Any) -> Any: + if ( + isinstance(idx, tuple) + and len(idx) == 2 + and (isinstance(idx[1], (str, Iterable)) or idx[1] is None) + ): + idx, column = idx + else: + column = None + return self.at(idx, column) + + def __len__(self) -> int: + return self.size diff --git a/src/conette/huggingface/config.py b/src/conette/huggingface/config.py index 2103750c4..bdef56076 100644 --- a/src/conette/huggingface/config.py +++ b/src/conette/huggingface/config.py @@ -1,10 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import logging + from typing import Any, Iterable, Optional from transformers import PretrainedConfig +pylog = logging.getLogger(__name__) + class CoNeTTEConfig(PretrainedConfig): def __init__( @@ -47,6 +51,11 @@ def __init__( tokenizer_state: Optional[dict[str, Any]] = None, **kwargs, ) -> None: + if len(kwargs) > 0: + pylog.warning( + f"Unknown {len(kwargs)} keywords arguments for {self.__class__.__name__}. 
(found {tuple(kwargs.keys())})" + ) + betas = list(betas) # type: ignore super().__init__() self.task_mode = task_mode diff --git a/src/conette/huggingface/model.py b/src/conette/huggingface/model.py index 632c27f97..55926ddeb 100644 --- a/src/conette/huggingface/model.py +++ b/src/conette/huggingface/model.py @@ -3,7 +3,7 @@ import logging -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Optional, TypedDict, Union import pickle import torch @@ -17,11 +17,21 @@ from conette.nn.functional.get import get_device from conette.pl_modules.conette import CoNeTTEPLM from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.transforms.audioset_labels import probs_to_labels pylog = logging.getLogger(__name__) +class CoNeTTEOutput(TypedDict): + cands: list[str] + tasks: list[str] + preds: Tensor + lprobs: Tensor + tags: list[list[str]] + tags_probs: Tensor + + class CoNeTTEModel(PreTrainedModel): """CoNeTTE PreTrainedModel for inference.""" @@ -30,8 +40,10 @@ def __init__( config: CoNeTTEConfig, device: Union[str, torch.device, None] = "auto", inference: bool = True, + offline: bool = False, + model_override: Optional[CoNeTTEPLM] = None, ) -> None: - setup_other_models() + setup_other_models(offline) if config.tokenizer_state is None: tokenizer = AACTokenizer() @@ -39,36 +51,39 @@ def __init__( tokenizer = AACTokenizer.from_txt_state(config.tokenizer_state) preprocessor = CoNeTTEPreprocessor(verbose=config.verbose) - model = CoNeTTEPLM( - task_mode=config.task_mode, - task_names=config.task_names, - gen_test_cands=config.gen_test_cands, - label_smoothing=config.label_smoothing, - gen_val_cands=config.gen_val_cands, - mixup_alpha=config.mixup_alpha, - proj_name=config.proj_name, - min_pred_size=config.min_pred_size, - max_pred_size=config.max_pred_size, - beam_size=config.beam_size, - nhead=config.nhead, - d_model=config.d_model, - num_decoder_layers=config.num_decoder_layers, - decoder_dropout_p=config.decoder_dropout_p, - dim_feedforward=config.dim_feedforward, - acti_name=config.acti_name, - optim_name=config.optim_name, - lr=config.lr, - weight_decay=config.weight_decay, - betas=config.betas, - eps=config.eps, - use_custom_wd=config.use_custom_wd, - sched_name=config.sched_name, - sched_n_steps=config.sched_n_steps, - sched_interval=config.sched_interval, - sched_freq=config.sched_freq, - train_tokenizer=tokenizer, - verbose=config.verbose, - ) + if model_override is not None: + model = model_override + else: + model = CoNeTTEPLM( + task_mode=config.task_mode, + task_names=config.task_names, + gen_test_cands=config.gen_test_cands, + label_smoothing=config.label_smoothing, + gen_val_cands=config.gen_val_cands, + mixup_alpha=config.mixup_alpha, + proj_name=config.proj_name, + min_pred_size=config.min_pred_size, + max_pred_size=config.max_pred_size, + beam_size=config.beam_size, + nhead=config.nhead, + d_model=config.d_model, + num_decoder_layers=config.num_decoder_layers, + decoder_dropout_p=config.decoder_dropout_p, + dim_feedforward=config.dim_feedforward, + acti_name=config.acti_name, + optim_name=config.optim_name, + lr=config.lr, + weight_decay=config.weight_decay, + betas=config.betas, + eps=config.eps, + use_custom_wd=config.use_custom_wd, + sched_name=config.sched_name, + sched_n_steps=config.sched_n_steps, + sched_interval=config.sched_interval, + sched_freq=config.sched_freq, + train_tokenizer=tokenizer, + verbose=config.verbose, + ) super().__init__(config) self.config: CoNeTTEConfig @@ -162,22 +177,27 @@ def forward( sr: Union[None, 
int, Iterable[int]] = None, x_shapes: Union[Tensor, None, list[Size]] = None, preprocess: bool = True, + threshold: Union[float, Tensor] = 0.3, # Beam search options task: Union[str, list[str], None] = None, beam_size: Optional[int] = None, min_pred_size: Optional[int] = None, max_pred_size: Optional[int] = None, forbid_rep_mode: Optional[str] = None, - ) -> dict[str, Any]: + ) -> CoNeTTEOutput: # Preprocessing (load data + encode features) if preprocess: batch = self.preprocessor(x, sr, x_shapes) + clip_probs = batch.pop("clip_probs") + tags = probs_to_labels(clip_probs, threshold, True, self.config.verbose) else: assert isinstance(x, Tensor) and isinstance(x_shapes, Tensor) batch: dict[str, Any] = { "audio": x.to(self.device), "audio_shape": x_shapes.to(self.device), } + clip_probs = None + tags = None # Add task information to batch bsize = len(batch["audio"]) @@ -222,6 +242,10 @@ def forward( outs = self.model(batch, **kwds) outs["tasks"] = tasks + if clip_probs is not None and tags is not None: + outs["tags_probs"] = clip_probs + outs["tags"] = tags + return outs def __call__( @@ -230,17 +254,21 @@ def __call__( x: Union[Tensor, str, Iterable[str], Iterable[Tensor]], sr: Union[None, int, Iterable[int]] = None, x_shapes: Union[Tensor, None, list[Size]] = None, + preprocess: bool = True, + threshold: Union[float, Tensor] = 0.3, # Beam search options task: Union[str, list[str], None] = None, beam_size: Optional[int] = None, min_pred_size: Optional[int] = None, max_pred_size: Optional[int] = None, forbid_rep_mode: Optional[str] = None, - ) -> dict[str, Any]: + ) -> CoNeTTEOutput: return super().__call__( x=x, sr=sr, x_shapes=x_shapes, + preprocess=preprocess, + threshold=threshold, task=task, beam_size=beam_size, min_pred_size=min_pred_size, diff --git a/src/conette/huggingface/preprocessor.py b/src/conette/huggingface/preprocessor.py index af6eb09a1..0f88155c9 100644 --- a/src/conette/huggingface/preprocessor.py +++ b/src/conette/huggingface/preprocessor.py @@ -61,6 +61,7 @@ def forward( frame_embs = outs["frame_embs"] frame_embs_lens = outs["frame_embs_lens"] + clip_probs = outs["clipwise_output"] # Transpose (bsize, feat_size, time) -> (bsize, time, features=768) frame_embs = frame_embs.transpose(1, 2) @@ -69,7 +70,11 @@ def forward( ) del frame_embs_lens - batch = {"audio": frame_embs, "audio_shape": audio_shape} + batch = { + "audio": frame_embs, + "audio_shape": audio_shape, + "clip_probs": clip_probs, + } return batch def _load(self, path: str) -> tuple[Tensor, int]: diff --git a/src/conette/huggingface/setup.py b/src/conette/huggingface/setup.py index 9cdf947ea..1577d93c1 100644 --- a/src/conette/huggingface/setup.py +++ b/src/conette/huggingface/setup.py @@ -9,6 +9,8 @@ import nltk +from conette.transforms.audioset_labels import download_audioset_mapping + pylog = logging.getLogger(__name__) @@ -32,3 +34,6 @@ def setup_other_models(offline: bool = False, verbose: int = 0) -> None: # Download stopwords list for constrained beam search nltk_model = "stopwords" nltk.download(nltk_model, quiet=verbose <= 0) + + # Download Audioset mapping ids in cache + download_audioset_mapping(verbose) diff --git a/src/conette/metrics/__init__.py b/src/conette/metrics/__init__.py new file mode 100644 index 000000000..faa18be5b --- /dev/null +++ b/src/conette/metrics/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- diff --git a/src/conette/metrics/classes/__init__.py b/src/conette/metrics/classes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/src/conette/metrics/classes/all_metrics.py b/src/conette/metrics/classes/all_metrics.py new file mode 100644 index 000000000..bc3442430 --- /dev/null +++ b/src/conette/metrics/classes/all_metrics.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import tempfile +import time + +from typing import Iterable, Optional, Union + +import torch + +from torch import Tensor +from torch.nn.parameter import Parameter + +from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs +from aac_metrics.classes.bleu import BLEU +from aac_metrics.classes.meteor import METEOR +from aac_metrics.classes.rouge_l import ROUGEL +from aac_metrics.classes.evaluate import Evaluate +from aac_metrics.classes.fense import FENSE +from aac_metrics.classes.spider import SPIDEr +from aac_metrics.functional.spider_fl import _spider_fl_from_outputs + +from conette.metrics.classes.diversity import Diversity +from conette.metrics.classes.new_words import NewWords +from conette.metrics.classes.text_stats import TextStats +from conette.nn.functional.get import get_device + + +pylog = logging.getLogger(__name__) + + +class AllMetrics(Evaluate): + def __init__( + self, + preprocess: bool = True, + device: Union[str, torch.device, None] = "auto", + cache_path: str = "~/.cache", + java_path: str = "java", + tmp_path: str = tempfile.gettempdir(), + meteor_java_max_memory: str = "2G", + spice_n_threads: Optional[int] = None, + spice_java_max_memory: str = "8G", + spice_timeout: Union[None, int, Iterable[int]] = None, + train_vocab: Optional[Iterable[str]] = None, + verbose: int = 0, + metrics_names: Union[None, Iterable[str], str] = None, + ) -> None: + device = get_device(device) + + if verbose >= 2: + pylog.debug(f"Use {device=} for metrics.") + + if metrics_names is None: + metrics_names = [ + "bert_score", + "diversity", + "text_stats", + "new_words", + "bleu_1", + "bleu_2", + "bleu_3", + "bleu_4", + "meteor", + "rouge_l", + "spider", + "fense", + "spider_fl", + ] + elif isinstance(metrics_names, str): + metrics_names = [metrics_names] + else: + metrics_names = list(metrics_names) + + return_all_scores: bool = True + metrics = [] + + if "bert_score" in metrics_names: + metric = BERTScoreMRefs(return_all_scores, device=device, verbose=verbose) + metrics.append(metric) + + if "diversity" in metrics_names: + metric = Diversity(return_all_scores, n_max=3) + metrics.append(metric) + + if "text_stats" in metrics_names: + metric = TextStats(return_all_scores) + metrics.append(metric) + + if "new_words" in metrics_names and train_vocab is not None: + metric = NewWords(return_all_scores, train_vocab) + metrics.append(metric) + + if "bleu_1" in metrics_names: + metric = BLEU(return_all_scores, n=1) + metrics.append(metric) + + if "bleu_2" in metrics_names: + metric = BLEU(return_all_scores, n=2) + metrics.append(metric) + + if "bleu_3" in metrics_names: + metric = BLEU(return_all_scores, n=3) + metrics.append(metric) + + if "bleu_4" in metrics_names: + metric = BLEU(return_all_scores, n=4) + metrics.append(metric) + + if "meteor" in metrics_names: + metric = METEOR( + return_all_scores, + cache_path=cache_path, + java_path=java_path, + java_max_memory=meteor_java_max_memory, + verbose=verbose, + ) + metrics.append(metric) + + if "rouge_l" in metrics_names: + metric = ROUGEL(return_all_scores) + metrics.append(metric) + + if "spider" in metrics_names: + metric = SPIDEr( + return_all_scores, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + n_threads=spice_n_threads, + 
java_max_memory=spice_java_max_memory, + timeout=spice_timeout, + verbose=verbose, + ) + metrics.append(metric) + + if "fense" in metrics_names: + metric = FENSE( + return_all_scores, device=device, return_probs=True, verbose=verbose + ) + metrics.append(metric) + + super().__init__( + preprocess=preprocess, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + verbose=verbose, + metrics=metrics, + ) + self._verbose = verbose + self._metrics_names = metrics_names + + self.register_parameter( + "placeholder", Parameter(torch.empty((0,), device=device)) + ) + self.placeholder: Parameter + + def compute(self) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + outs = super().compute() + + if "spider_fl" in self._metrics_names: + name = "SPIDEr-FL" + if self._verbose >= 1: + pylog.info(f"Computing {name} to outputs...") + + start = time.perf_counter() + outs = _spider_fl_from_outputs(outs, outs) + end = time.perf_counter() + + if self._verbose >= 1: + duration = end - start + pylog.info(f"Metric {name} computed in {duration:.2f}s.") + + return outs + + def extra_repr(self) -> str: + return f"len={len(self)}, device={self.device}" + + @property + def device(self) -> torch.device: + return self.placeholder.device diff --git a/src/conette/metrics/classes/bert_score_mrefs.ign.py b/src/conette/metrics/classes/bert_score_mrefs.ign.py new file mode 100644 index 000000000..60cbc8737 --- /dev/null +++ b/src/conette/metrics/classes/bert_score_mrefs.ign.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Union + +import torch + +from torch import nn, Tensor +from torchmetrics.text.bert import _DEFAULT_MODEL + +from aac_metrics.classes.base import AACMetric + +from conette.metrics.functional.bert_score_mrefs import ( + bert_score_mrefs, + _load_model_and_tokenizer, +) + + +class BERTScoreMRefs(AACMetric): + """BERTScore metric which supports multiple references. + + The implementation is based on the bert_score implementation of torchmetrics. + + - Paper: https://arxiv.org/pdf/1904.09675.pdf + + For more information, see :func:`~aac_metrics.functional.bert_score.bert_score_mrefs`. 
+ """ + + full_state_update = False + higher_is_better = True + is_differentiable = False + + min_value = 0.0 + max_value = 1.0 + + def __init__( + self, + return_all_scores: bool = True, + model: Union[str, nn.Module] = _DEFAULT_MODEL, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + num_threads: int = 0, + max_length: int = 64, + reset_state: bool = True, + verbose: int = 0, + ) -> None: + model, tokenizer = _load_model_and_tokenizer( + model, None, device, reset_state, verbose + ) + + super().__init__() + self._return_all_scores = return_all_scores + self._model = model + self._tokenizer = tokenizer + self._device = device + self._batch_size = batch_size + self._num_threads = num_threads + self._max_length = max_length + self._reset_state = reset_state + self._verbose = verbose + + self._candidates = [] + self._mult_references = [] + + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bert_score_mrefs( + self._candidates, + self._mult_references, + self._return_all_scores, + self._model, + self._tokenizer, + self._device, + self._batch_size, + self._num_threads, + self._max_length, + self._reset_state, + self._verbose, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ( + "bert_score.precision", + "bert_score.recalll", + "bert_score.f1", + ) + + def reset(self) -> None: + self._candidates = [] + self._mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self._candidates += candidates + self._mult_references += mult_references diff --git a/src/conette/metrics/classes/diversity.py b/src/conette/metrics/classes/diversity.py new file mode 100644 index 000000000..160a6db07 --- /dev/null +++ b/src/conette/metrics/classes/diversity.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math + +from typing import Callable, Union + +import torch + +from torch import Tensor + +from aac_metrics.classes.base import AACMetric + +from conette.metrics.functional.diversity import ( + _diversity_compute, + _diversity_update, +) + + +class Diversity(AACMetric): + full_state_update = False + higher_is_better = True + is_differentiable = False + + min_value = 0.0 + max_value = math.inf + + def __init__( + self, + return_all_scores: bool = True, + n_max: int = 1, + cumulative: bool = False, + use_ngram_count: bool = True, + seed: Union[None, int, torch.Generator] = 123, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__() + self._return_all_scores = return_all_scores + self._n_max = n_max + self._cumulative = cumulative + self._use_ngram_count = use_ngram_count + self._seed = seed + self._tokenizer = tokenizer + + self._tok_cands = [] + self._tok_mrefs = [] + + # Metric methods + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return _diversity_compute( + tok_cands=self._tok_cands, + tok_mrefs=self._tok_mrefs, + return_all_scores=self._return_all_scores, + n_max=self._n_max, + cumulative=self._cumulative, + use_ngram_count=self._use_ngram_count, + seed=self._seed, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ( + f"sents_div{self._n_max}.cands", + f"sents_div{self._n_max}.mrefs", + f"sents_div{self._n_max}.ratio", + f"corpus_div{self._n_max}.cands", + f"corpus_div{self._n_max}.mrefs", + f"corpus_div{self._n_max}.ratio", + ) + + def reset(self) -> None: + self._tok_cands = [] + self._tok_mrefs = [] + return super().reset() + + 
def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self._tok_cands, self._tok_mrefs = _diversity_update( + candidates, + mult_references, + self._tokenizer, + self._tok_cands, + self._tok_mrefs, + ) diff --git a/src/conette/metrics/classes/jaccard.py b/src/conette/metrics/classes/jaccard.py new file mode 100644 index 000000000..c1fe08f03 --- /dev/null +++ b/src/conette/metrics/classes/jaccard.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Optional, Union + +from nltk.stem import SnowballStemmer +from torch import Tensor +from torchmetrics import Metric + +from conette.metrics.functional.jaccard import jaccard + + +class Jaccard(Metric): + """Jaccard similarity, also known as "intersection over union".""" + + is_differentiable = False + higher_is_better = True + full_state_update = False + + min_value = 0.0 + max_value = 1.0 + + def __init__( + self, + return_all_scores: bool = True, + stemmer_lang: Optional[str] = "english", + ) -> None: + if stemmer_lang is not None: + stemmer = SnowballStemmer(stemmer_lang) + else: + stemmer = None + + super().__init__() + self.return_all_scores = return_all_scores + self.stemmer_lang = stemmer_lang + self.stemmer = stemmer + + self.candidates = [] + self.mult_references = [] + + # Metric methods + def compute(self) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: + return jaccard( + self.candidates, + self.mult_references, + self.return_all_scores, + self.stemmer, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ("jaccard",) + + def reset(self) -> None: + self.candidates = [] + self.mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self.candidates += candidates + self.mult_references += mult_references + + # Magic methods + def __getstate__(self) -> dict[str, Any]: + return { + "return_all_scores": self.return_all_scores, + "stemmer_lang": self.stemmer_lang, + "candidates": self.candidates, + "mult_references": self.mult_references, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.return_all_scores = state["return_all_scores"] + self.stemmer_lang = state["stemmer_lang"] + self.candidates = state["candidates"] + self.mult_references = state["mult_references"] + + if self.stemmer_lang is not None: + stemmer = SnowballStemmer(self.stemmer_lang) + else: + stemmer = None + self.stemmer = stemmer diff --git a/src/conette/metrics/classes/new_words.py b/src/conette/metrics/classes/new_words.py new file mode 100644 index 000000000..163bb6089 --- /dev/null +++ b/src/conette/metrics/classes/new_words.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math + +from typing import Iterable, Union + +from torch import Tensor +from torchmetrics import Metric + +from conette.metrics.functional.new_words import new_words + + +class NewWords(Metric): + """Count of candidate words that do not appear in the training vocabulary.""" + + is_differentiable = False + higher_is_better = True + full_state_update = False + + min_value = 0.0 + max_value = math.inf + + def __init__( + self, + return_all_scores: bool = True, + train_vocab: Iterable[str] = (), + ) -> None: + train_vocab = dict.fromkeys(train_vocab) + + super().__init__() + self.return_all_scores = return_all_scores + self.train_vocab = train_vocab + + self.candidates = [] + self.mult_references = [] + + # Metric methods + def compute(self) -> Union[Tensor,
tuple[dict[str, Tensor], dict[str, Tensor]]]: + return new_words( + self.candidates, + self.mult_references, + self.return_all_scores, + self.train_vocab, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ("new_words",) + + def reset(self) -> None: + self.candidates = [] + self.mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self.candidates += candidates + self.mult_references += mult_references diff --git a/src/conette/metrics/classes/null.py b/src/conette/metrics/classes/null.py new file mode 100644 index 000000000..e8a2e8a86 --- /dev/null +++ b/src/conette/metrics/classes/null.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from torchmetrics import Metric + + +class NullMetric(Metric): + """Placeholder Metric. Method `compute` always returns 0.""" + + is_differentiable = False + higher_is_better = True + full_state_update = False + + min_value = 0.0 + max_value = 0.0 + + # Metric methods + def update(self, *args, **kwargs) -> None: + pass + + def compute(self) -> float: + return 0.0 diff --git a/src/conette/metrics/classes/self_bleu.py b/src/conette/metrics/classes/self_bleu.py new file mode 100644 index 000000000..0bc1dcec4 --- /dev/null +++ b/src/conette/metrics/classes/self_bleu.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Optional, Union + +import torch + +from torch import Tensor +from torchmetrics import Metric + +from conette.metrics.functional.self_bleu import self_bleu + + +class SelfBleuCands(Metric): + is_differentiable = False + higher_is_better = False + full_state_update = False + + min_value = 0.0 + max_value = 1.0 + + def __init__( + self, + max_ngram_sizes: int = 4, + max_refs: Optional[int] = None, + generator: Union[None, int, torch.Generator] = 1234, + ) -> None: + super().__init__() + self.max_ngram_sizes = max_ngram_sizes + self.max_refs = max_refs + self.generator = generator + self.candidates = [] + + # Metric methods + def compute(self) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + return self_bleu( + self.candidates, + self.max_ngram_sizes, + self.max_refs, + self.generator, + ) + + def reset(self) -> None: + self.candidates = [] + return super().reset() + + def update( + self, + candidates: list[list[str]], + mult_references: list[list[list[str]]], + ) -> None: + self.candidates += candidates + + +class SelfBleuMRefs(Metric): + is_differentiable = False + higher_is_better = False + full_state_update = False + + def __init__( + self, + max_ngram_sizes: int = 4, + max_refs: Optional[int] = None, + generator: Union[None, int, torch.Generator] = 1234, + ) -> None: + super().__init__() + self.max_ngram_sizes = max_ngram_sizes + self.max_refs = max_refs + self.generator = generator + self.references_flat = [] + + # Metric methods + def compute(self) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + return self_bleu( + self.references_flat, + self.max_ngram_sizes, + self.max_refs, + self.generator, + ) + + def reset(self) -> None: + self.references_flat = [] + return super().reset() + + def update( + self, + candidates: list[list[str]], + mult_references: list[list[list[str]]], + ) -> None: + references_flat = [ref for refs in mult_references for ref in refs] + self.references_flat += references_flat diff --git a/src/conette/metrics/classes/tensor.py b/src/conette/metrics/classes/tensor.py new file mode 100644 index 000000000..3fc715ed9 --- /dev/null +++ 
b/src/conette/metrics/classes/tensor.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Iterable, Optional, Union + +import torch + +from torch import nn, Tensor +from torchmetrics.metric import Metric + +from conette.nn.functional.mask import tensor_to_lengths, tensor_to_non_pad_mask + + +class MeanPredLen(Metric): + is_differentiable: Optional[bool] = False + higher_is_better: Optional[bool] = None + full_state_update: bool = True + + def __init__(self, eos_id: int) -> None: + super().__init__() + self.eos_id = eos_id + + self.sum_lens: Tensor + self.total: Tensor + self.add_state("sum_lens", default=torch.zeros(()), dist_reduce_fx="sum") + self.add_state("total", default=torch.zeros(()), dist_reduce_fx="sum") + + # Metric methods + def update( + self, + preds: Tensor, + ) -> None: + assert preds.ndim == 2 + lengths = tensor_to_lengths(preds, end_value=self.eos_id) + + self.sum_lens += lengths.sum() + self.total += lengths.shape[0] + + def compute(self) -> Tensor: + return self.sum_lens / self.total + + +class TensorDiversity1(Metric): + """Compute Diversity on 1-gram (also called Type-Token Ratio) on encoded sentences tensors.""" + + is_differentiable: Optional[bool] = False + higher_is_better: Optional[bool] = None + full_state_update: bool = True + + def __init__( + self, + eos: int, + excluded: Union[Iterable[int], Tensor] = (), + return_sents_scores: bool = False, + ) -> None: + if isinstance(excluded, Tensor): + excluded = excluded.flatten().tolist() + super().__init__() + self._eos = eos + self._excluded = list(excluded) + self._return_sents_scores = return_sents_scores + + self.scores: list[Tensor] + self.add_state("scores", default=[], dist_reduce_fx=None) + + # Metric methods + def update( + self, + preds: Tensor, + ) -> None: + """ + :param preds: Tensor of shape (bsize, N) + """ + assert preds.ndim == 2 + preds_mask = tensor_to_non_pad_mask(preds, end_value=self._eos) + + scores = torch.empty((preds.shape[0],), dtype=torch.float, device=preds.device) + for i, (pred, mask) in enumerate(zip(preds, preds_mask)): + for value in self._excluded: + mask &= pred.ne(value) + vocab = pred[mask].unique() + scores[i] = vocab.shape[0] / max(mask.sum().item(), 1) + + self.scores.append(scores) + + def compute(self) -> Tensor: + if len(self.scores) > 0: + scores = torch.cat(self.scores) + if self._return_sents_scores: + return scores + else: + return scores.mean() + else: + return torch.zeros((), dtype=torch.float, device=self.device) + + +class GlobalTensorVocabUsage(nn.Module): + r"""Global Vocab Usage. 
+ + Returns \frac{|hyp\_vocab|}{|ref\_vocab|} + """ + + def __init__(self, ignored_indexes: Union[Iterable[int], Tensor] = ()) -> None: + if isinstance(ignored_indexes, Tensor): + ignored_indexes = ignored_indexes.flatten().tolist() + super().__init__() + self._ignored_indexes = list(ignored_indexes) + self._preds_vocab = None + self._captions_vocab = None + + # Metric methods + def reset(self) -> None: + self._preds_vocab = None + self._captions_vocab = None + + def forward(self, preds: Tensor, captions: Tensor) -> float: + """ + :param preds: (bsize, pred_len) tensor + :param captions: (bsize, capt_len) tensor + """ + self.update(preds, captions) + return self.compute() + + def update(self, preds: Tensor, captions: Tensor) -> None: + # Remove ignored token indexes before building the vocabularies + ignored = torch.as_tensor(self._ignored_indexes, dtype=preds.dtype, device=preds.device) + preds = preds[~torch.isin(preds, ignored)] + captions = captions[~torch.isin(captions, ignored)] + + preds_vocab = torch.unique(preds) + captions_vocab = torch.unique(captions) + + if self._preds_vocab is None: + self._preds_vocab = preds_vocab + else: + self._preds_vocab = torch.unique(torch.cat([self._preds_vocab, preds_vocab])) + + if self._captions_vocab is None: + self._captions_vocab = captions_vocab + else: + self._captions_vocab = torch.unique( + torch.cat([self._captions_vocab, captions_vocab]) + ) + + def compute(self) -> float: + if ( + self._preds_vocab is not None + and self._captions_vocab is not None + and len(self._captions_vocab) > 0 + ): + return len(self._preds_vocab) / len(self._captions_vocab) + else: + return 0.0 diff --git a/src/conette/metrics/classes/text_stats.py b/src/conette/metrics/classes/text_stats.py new file mode 100644 index 000000000..b76fc2c0b --- /dev/null +++ b/src/conette/metrics/classes/text_stats.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Callable, Union + +import torch + +from torch import Tensor + +from aac_metrics.classes.base import AACMetric + +from conette.metrics.functional.text_stats import text_stats + + +class TextStats(AACMetric): + full_state_update = False + higher_is_better = True + is_differentiable = False + + def __init__( + self, + return_all_scores: bool = True, + seed: Union[None, int, torch.Generator] = 123, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__() + self._return_all_scores = return_all_scores + self._seed = seed + self._tokenizer = tokenizer + + self._candidates = [] + self._mult_references = [] + + # Metric methods + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return text_stats( + self._candidates, + self._mult_references, + self._return_all_scores, + self._seed, + self._tokenizer, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ( + "sent_len.cands", + "sent_len.mrefs", + "sent_len.ratio", + "vocab_len.cands", + "vocab_len.mrefs_full", + "vocab_len.ratio_full", + "vocab_len.mrefs_avg", + "vocab_len.ratio_avg", + "vocab_coverage", + "vocab_in_ref_len", + "vocab_in_ref_ratio", + "empty_sents", + ) + + def reset(self) -> None: + self._candidates = [] + self._mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self._candidates += candidates + self._mult_references += mult_references diff --git a/src/conette/metrics/classes/wmd.py b/src/conette/metrics/classes/wmd.py new file mode 100644 index 000000000..246045736 --- /dev/null +++ b/src/conette/metrics/classes/wmd.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import
math + +from typing import Any, Callable, Union + +from gensim import downloader +from gensim.downloader import load +from torch import Tensor +from torchmetrics import Metric + +from conette.metrics.functional.wmd import wmdistance + + +pylog = logging.getLogger(__name__) + + +class WMDistance(Metric): + """Word Mover Distance. + + Output is in range [0, +inf[. + """ + + is_differentiable = False + higher_is_better = False + full_state_update = False + + min_value = 0.0 + max_value = math.inf + + def __init__( + self, + return_all_scores: bool = True, + tokenizer: Callable[[str], list[str]] = str.split, + model_name: str = "word2vec-google-news-300", + verbose: int = 0, + ) -> None: + if verbose >= 2: + pylog.debug(f"Gensim data base dir: {downloader.BASE_DIR=}.") + + super().__init__() + self._return_all_scores = return_all_scores + self._tokenizer = tokenizer + self._model_name = model_name + self._model = load(model_name, return_path=False) + self._verbose = verbose + + if verbose >= 2: + path: str = load(model_name, return_path=True) # type: ignore + pylog.debug(f"Load gensim model {model_name=} from {path=}.") + + self._candidates = [] + self._mult_references = [] + + # Metric methods + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return wmdistance( + self._candidates, + self._mult_references, + self._return_all_scores, + self._tokenizer, + self._model, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ("wmd",) + + def reset(self) -> None: + self._candidates = [] + self._mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self._candidates += candidates + self._mult_references += mult_references + + def __getstate__(self) -> dict[str, Any]: + return { + "tokenizer": self._tokenizer, + "model_name": self._model_name, + "candidates": self._candidates, + "mult_references": self._mult_references, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self._tokenizer = state["tokenizer"] + self._model_name = state["model_name"] + self._candidates = state["candidates"] + self._mult_references = state["mult_references"] + + self._model = load(self._model_name) diff --git a/src/conette/metrics/cross_referencing.py b/src/conette/metrics/cross_referencing.py new file mode 100644 index 000000000..d27ae65bd --- /dev/null +++ b/src/conette/metrics/cross_referencing.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import sys + +from typing import Union + +import torch +import tqdm + +from torch import Tensor + +from aac_metrics.classes.base import AACMetric +from aac_metrics.utils.tokenization import preprocess_mult_sents +from conette.metrics.classes.all_metrics import AllMetrics + + +pylog = logging.getLogger(__name__) +MODES = ("random", "columns") + + +def compute_cross_referencing( + msents: list[list[str]], + mode: str = "random", + seed: Union[int, torch.Generator, None] = 1234, + preprocess: bool = True, + max_cross_refs: int = sys.maxsize, + metrics: Union[None, AACMetric] = None, + verbose: int = 1, +) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: + """Compute cross-referencing 'Human' scores for all metrics. + + Works only when all multiple sentences have the same number of sentences individually. + """ + if len(msents) == 0: + raise ValueError( + "Invalid number of mult sentences. 
(expected at least 1 set of sentences)" + ) + + n_sents_per_item_lst = list(map(len, msents)) + if not all(n_sents == n_sents_per_item_lst[0] for n_sents in n_sents_per_item_lst): + n_sents_set = list(set(n_sents_per_item_lst)) + raise ValueError( + f"Invalid n_sents list. (found different number of sentences per item with {n_sents_set=})" + ) + + n_sents_per_item = n_sents_per_item_lst[0] + del n_sents_per_item_lst + if n_sents_per_item <= 1: + raise ValueError(f"Cannot compute cross-referencing with {n_sents_per_item=}") + elif verbose >= 2: + pylog.debug(f"Found {n_sents_per_item=}.") + + if isinstance(seed, int): + gen = torch.Generator().manual_seed(seed) + else: + gen = seed + + if preprocess: + msents = preprocess_mult_sents(msents) + + if metrics is None: + metrics = AllMetrics(preprocess=False) + + max_cross_refs = min(n_sents_per_item, max_cross_refs) + all_outs_corpus = [] + all_outs_sents = [] + + for i in tqdm.trange(max_cross_refs, disable=verbose < 1): + if mode == "columns": + cands_i = [sents[i] for sents in msents] + mrefs_not_i = [ + [sent for j, sent in enumerate(sents) if j != i] for sents in msents + ] + + elif mode == "random": + indexes = torch.randint(0, n_sents_per_item, (len(msents),), generator=gen) + indexes = indexes.tolist() + cands_i = [sents[idx] for sents, idx in zip(msents, indexes)] + mrefs_not_i = [ + [sent for j, sent in enumerate(sents) if j != idx] + for sents, idx in zip(msents, indexes) + ] + + else: + raise ValueError(f"Invalid argument {mode=}. (expected one of {MODES})") + + outs_corpus, outs_sents = metrics( + cands_i, + mrefs_not_i, + ) + all_outs_corpus.append(outs_corpus) + all_outs_sents.append(outs_sents) + + return all_outs_corpus, all_outs_sents diff --git a/src/conette/metrics/functional/__init__.py b/src/conette/metrics/functional/__init__.py new file mode 100644 index 000000000..faa18be5b --- /dev/null +++ b/src/conette/metrics/functional/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- diff --git a/src/conette/metrics/functional/bert_score_mrefs.ign.py b/src/conette/metrics/functional/bert_score_mrefs.ign.py new file mode 100644 index 000000000..2435a4073 --- /dev/null +++ b/src/conette/metrics/functional/bert_score_mrefs.ign.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Callable, Optional, TypeVar, Union + +import torch + +from torch import nn, Tensor +from torchmetrics.functional.text.bert import bert_score, _DEFAULT_MODEL +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers import logging as tfmers_logging + +from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list + + +T = TypeVar("T") + + +def bert_score_mrefs( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + model: Union[str, nn.Module] = _DEFAULT_MODEL, + tokenizer: Optional[Callable] = None, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + num_threads: int = 0, + max_length: int = 64, + reset_state: bool = True, + idf: bool = False, + verbose: int = 0, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + """BERTScore metric which supports multiple references. + + The implementation is based on the bert_score implementation of torchmetrics. + + - Paper: https://arxiv.org/pdf/1904.09675.pdf + + :param candidates: The list of sentences to evaluate. 
+ :param mult_references: The list of list of sentences used as target. + :param return_all_scores: If True, returns a tuple containing the globals and locals scores. + Otherwise returns a scalar tensor containing the main global score. + defaults to True. + :param model: The model name or the instantiated model to use to compute token embeddings. + defaults to "roberta-large". + :param tokenizer: The fast tokenizer used to split sentences into words. + If None, use the tokenizer corresponding to the model argument. + defaults to None. + :param device: The PyTorch device used to run the BERT model. defaults to "auto". + :param batch_size: The batch size used in the model forward. + :param num_threads: A number of threads to use for a dataloader. defaults to 0. + :param max_length: Max length when encoding sentences to tensor ids. defaults to 64. + :param idf: Whether or not using Inverse document frequency to ponderate the BERTScores. defaults to False. + :param verbose: The verbose level. defaults to 0. + :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. + """ + + if isinstance(model, str): + if tokenizer is not None: + raise ValueError( + f"Invalid argument combinaison {model=} with {tokenizer=}." + ) + model, tokenizer = _load_model_and_tokenizer( + model, tokenizer, device, reset_state, verbose + ) + + elif isinstance(model, nn.Module): + if tokenizer is None: + raise ValueError( + f"Invalid argument combinaison {model=} with {tokenizer=}." + ) + + else: + raise ValueError( + f"Invalid argument type {type(model)=}. (expected str or nn.Module)" + ) + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + flat_mrefs, sizes = flat_list(mult_references) + duplicated_cands = duplicate_list(candidates, sizes) + + tfmers_verbosity = tfmers_logging.get_verbosity() + if verbose <= 1: + tfmers_logging.set_verbosity_error() + + sents_scores = bert_score( + duplicated_cands, + flat_mrefs, + model_name_or_path=None, + model=model, # type: ignore + user_tokenizer=tokenizer, + device=device, + batch_size=batch_size, + num_threads=num_threads, + verbose=verbose >= 3, + max_length=max_length, + idf=idf, + ) + if verbose <= 1: + # Restore previous verbosity level + tfmers_logging.set_verbosity(tfmers_verbosity) + + # sents_scores keys: "precision", "recall", "f1" + sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()} # type: ignore + + if not return_all_scores: + sents_scores = {"f1": sents_scores["f1"]} + + dtype = torch.float32 + if len(sizes) > 0 and all(size == sizes[0] for size in sizes): + sents_scores = { + k: torch.as_tensor(v, dtype=dtype).mean(dim=1) + for k, v in sents_scores.items() + } + else: + sents_scores = { + k: torch.stack([torch.as_tensor(vi, dtype=dtype).mean() for vi in v]) + for k, v in sents_scores.items() + } + + sents_scores = {f"bert_score.{k}": v for k, v in sents_scores.items()} + sents_scores = {k: v.masked_fill(v.isnan(), 0.0) for k, v in sents_scores.items()} + + corpus_scores = {k: v.mean() for k, v in sents_scores.items()} + + if return_all_scores: + return corpus_scores, sents_scores + else: + return corpus_scores["bert_score.f1"] + + +def _load_model_and_tokenizer( + model: Union[str, nn.Module], + tokenizer: Optional[Callable], + device: Union[str, torch.device, None], + reset_state: bool, + verbose: int, +) -> tuple[nn.Module, Optional[Callable]]: + state = torch.random.get_rng_state() + + if device == "auto": 
+ device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if isinstance(model, str): + tfmers_verbosity = tfmers_logging.get_verbosity() + if verbose <= 1: + tfmers_logging.set_verbosity_error() + + # WARNING: tokenizer must be initialized BEFORE model to avoid connection errors + tokenizer = AutoTokenizer.from_pretrained(model) + model = AutoModel.from_pretrained(model) # type: ignore + + if verbose <= 1: + # Restore previous verbosity level + tfmers_logging.set_verbosity(tfmers_verbosity) + + model.eval() # type: ignore + model.to(device=device) # type: ignore + + if reset_state: + torch.random.set_rng_state(state) + + return model, tokenizer # type: ignore diff --git a/src/conette/metrics/functional/div_n.py b/src/conette/metrics/functional/div_n.py new file mode 100644 index 000000000..7c10426c9 --- /dev/null +++ b/src/conette/metrics/functional/div_n.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Callable, Union + +import torch + +from nltk.util import ngrams +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def div_n( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + n: int = 1, + tokenizer: Callable[[str], list[str]] = str.split, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + tok_cands = list(map(tokenizer, candidates)) + del candidates + + dtype = torch.float64 + diversities = _compute_div_n(tok_cands, n, dtype) + diversity = diversities.mean() + + if return_all_scores: + corpus_scores = { + "div": diversity, + } + sents_scores = { + "div": diversities, + } + return corpus_scores, sents_scores + else: + return diversity + + +def _compute_div_n( + sentences: list[list[str]], + n: int, + dtype: torch.dtype, +) -> Tensor: + diversities = [len(set(ngrams(sent, n))) / len(sent) for sent in sentences] + diversities = torch.as_tensor(diversities, dtype=dtype) + return diversities diff --git a/src/conette/metrics/functional/diversity.py b/src/conette/metrics/functional/diversity.py new file mode 100644 index 000000000..6742963c7 --- /dev/null +++ b/src/conette/metrics/functional/diversity.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Callable, Union + +import torch + +from nltk.util import ngrams +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def vocab_size( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + seed: Union[None, int, torch.Generator] = 123, + tokenizer: Callable[[str], list[str]] = str.split, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + outs: tuple[dict[str, Tensor], dict[str, Tensor]] = diversity( # type: ignore + candidates=candidates, + mult_references=mult_references, + return_all_scores=True, + n=1, + cumulative=False, + use_ngram_count=True, + seed=seed, + tokenizer=tokenizer, + ) + corpus_outs, sents_outs = outs + + if return_all_scores: + corpus_outs = { + ( + k.replace("sents_div1.", "sents_vocab.").replace( + "corpus_div1.", "corpus_vocab." 
+ ) + ): v + for k, v in corpus_outs.items() + } + sents_outs = { + (k.replace("sents_div1.", "sents_vocab.")): v for k, v in sents_outs.items() + } + return corpus_outs, sents_outs + else: + return corpus_outs["corpus_div1.cands"] + + +def diversity( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + n: int = 1, + cumulative: bool = False, + use_ngram_count: bool = True, + seed: Union[None, int, torch.Generator] = 123, + tokenizer: Callable[[str], list[str]] = str.split, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + """Compute sentences and corpus n-grams diversities ratios from candidates and references with n-grams from 1 to n. + + :param candidates: The list of sentences to evaluate. + :param mult_references: The list of list of sentences used as target. + :param return_all_scores: If True, returns a tuple containing the globals and locals scores. + Otherwise returns a scalar tensor containing the main global score. + defaults to True. + """ + tok_cands, tok_mrefs = _diversity_update( + candidates, + mult_references, + tokenizer, + [], + [], + ) + return _diversity_compute( + tok_cands, tok_mrefs, return_all_scores, n, cumulative, use_ngram_count, seed + ) + + +def _diversity_compute( + tok_cands: list[list[str]], + tok_mrefs: list[list[list[str]]], + return_all_scores: bool, + n_max: int, + cumulative: bool, + use_ngram_count: bool, + seed: Union[None, int, torch.Generator], +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + if len(tok_mrefs) <= 0: + raise ValueError( + f"Invalid number of references. (found {len(tok_mrefs)} references)" + ) + + dtype = torch.float64 + sents_divs_cands = torch.empty((len(tok_cands), n_max), dtype=dtype) + sents_divs_mrefs = torch.empty((len(tok_mrefs), n_max), dtype=dtype) + + for i, (cand, refs) in enumerate(zip(tok_cands, tok_mrefs)): + div_cand = _sent_diversities(cand, n_max, cumulative, use_ngram_count, dtype) + refs_divs = [ + _sent_diversities(ref, n_max, cumulative, use_ngram_count, dtype) + for ref in refs + ] + if len(refs_divs) > 0: + div_refs = sum(refs_divs) / len(refs_divs) + else: + div_refs = torch.zeros((n_max,), dtype=dtype) + + sents_divs_cands[i] = div_cand + sents_divs_mrefs[i] = div_refs + + sents_divs_ratios = torch.where( + sents_divs_mrefs != 0.0, sents_divs_cands / sents_divs_mrefs, 0.0 + ) + corpus_div_cands = _corpus_diversities( + tok_cands, n_max, cumulative, use_ngram_count, dtype + ) + + if isinstance(seed, int): + generator = torch.Generator().manual_seed(seed) + else: + generator = seed + + max_n_refs_per_audio = max(len(refs) for refs in tok_mrefs) + corpus_div_mrefs_all = torch.empty((max_n_refs_per_audio, n_max), dtype=dtype) + + for i in range(max_n_refs_per_audio): + indexes = [ + int(torch.randint(0, len(refs), (), generator=generator).item()) + for refs in tok_mrefs + ] + popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)] + corpus_div_mrefs_i = _corpus_diversities( + popped_refs, n_max, cumulative, use_ngram_count, dtype + ) + corpus_div_mrefs_all[i] = corpus_div_mrefs_i + + # corpus_div_mrefs_all: (n_refs_per_audio, n_max) + corpus_div_mrefs = corpus_div_mrefs_all.mean(dim=0) + corpus_div_ratio = torch.where( + corpus_div_mrefs != 0.0, + corpus_div_cands / corpus_div_mrefs, + 0.0, + ) + + sents_div_cands = sents_divs_cands.mean(dim=0) + sents_div_mrefs = sents_divs_mrefs.mean(dim=0) + sents_div_ratio = sents_divs_ratios.mean(dim=0) + + if return_all_scores: + corpus_outs = {} + sents_outs = {} + for n in 
range(1, n_max + 1): + corpus_outs |= { + f"sents_div{n}.cands": sents_div_cands[n - 1], + f"sents_div{n}.mrefs": sents_div_mrefs[n - 1], + f"sents_div{n}.ratio": sents_div_ratio[n - 1], + f"corpus_div{n}.cands": corpus_div_cands[n - 1], + f"corpus_div{n}.mrefs": corpus_div_mrefs[n - 1], + f"corpus_div{n}.ratio": corpus_div_ratio[n - 1], + } + sents_outs |= { + f"sents_div{n}.cands": sents_divs_cands[:, n - 1], + f"sents_div{n}.mrefs": sents_divs_mrefs[:, n - 1], + f"sents_div{n}.ratio": sents_divs_ratios[:, n - 1], + } + + return corpus_outs, sents_outs + else: + return sents_div_ratio[-1] + + +def _diversity_update( + candidates: list[str], + mult_references: list[list[str]], + tokenizer: Callable[[str], list[str]], + prev_tok_cands: list[list[str]], + prev_tok_mrefs: list[list[list[str]]], +) -> tuple[list[list[str]], list[list[list[str]]]]: + new_tok_cands = list(map(tokenizer, candidates)) + new_tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] + prev_tok_cands += new_tok_cands + prev_tok_mrefs += new_tok_mrefs + return prev_tok_cands, prev_tok_mrefs + + +def _sent_diversities( + sent: list[str], + n_max: int, + cumulative: bool, + use_ngram_count: bool, + dtype: torch.dtype, +) -> Tensor: + """ + :returns: tensor shape: (n_max,) + """ + diversities = torch.zeros((n_max,), dtype=dtype) + + if len(sent) == 0: + return diversities + + deno_count = torch.zeros((n_max,), dtype=dtype) + uniq_ngrams_count = torch.zeros((n_max,), dtype=dtype) + + for n in range(1, min(n_max, len(sent)) + 1): + ngrams_lst = list(ngrams(sent, n)) + ngrams_set = set(ngrams_lst) + + if use_ngram_count: + deno_count[n - 1] += len(ngrams_lst) + else: + deno_count[n - 1] += len(sent) + uniq_ngrams_count[n - 1] = len(ngrams_set) + + if cumulative: + uniq_ngrams_count = uniq_ngrams_count.cumsum(0) + deno_count = deno_count.cumsum(0) + + diversities = uniq_ngrams_count / deno_count.clamp(min=1.0) + arange = torch.arange(1, n_max + 1, dtype=dtype) + diversities = diversities / arange + + else: + diversities = uniq_ngrams_count / deno_count.clamp(min=1.0) + + return diversities + + +def _corpus_diversities( + sents: list[list[str]], + n_max: int, + cumulative: bool, + use_ngram_count: bool, + dtype: torch.dtype, +) -> Tensor: + """ + :returns: tensor shape: (n_max,) + """ + deno_count = torch.zeros((n_max,), dtype=dtype) + uniq_ngrams_sets = [set() for _ in range(n_max)] + + for sent in sents: + for n in range(1, min(n_max, len(sent)) + 1): + ngrams_lst = list(ngrams(sent, n)) + ngrams_set = set(ngrams_lst) + + if use_ngram_count: + deno_count[n - 1] += len(ngrams_lst) + else: + deno_count[n - 1] += len(sent) + uniq_ngrams_sets[n - 1] |= ngrams_set + + uniq_ngrams_count = torch.as_tensor([len(s) for s in uniq_ngrams_sets], dtype=dtype) + + if cumulative: + uniq_ngrams_count = uniq_ngrams_count.cumsum(0) + deno_count = deno_count.cumsum(0) + diversities = uniq_ngrams_count / deno_count.clamp(min=1.0) + arange = torch.arange(1, n_max + 1, dtype=dtype) + diversities = diversities / arange + + else: + diversities = uniq_ngrams_count / deno_count.clamp(min=1.0) + + return diversities diff --git a/src/conette/metrics/functional/jaccard.py b/src/conette/metrics/functional/jaccard.py new file mode 100644 index 000000000..daffe8792 --- /dev/null +++ b/src/conette/metrics/functional/jaccard.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Callable, Optional, Union + +import torch + +from nltk.stem import StemmerI +from torch import Tensor + + +def jaccard( + candidates: 
list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + stemmer: Optional[StemmerI] = None, + tokenizer: Callable[[str], list[str]] = str.split, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + """Compute Jaccard score.""" + + tok_cands = list(map(tokenizer, candidates)) + tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] + del candidates, mult_references + + jaccard_scores = torch.empty((len(tok_cands),), dtype=torch.float64) + + for i, (tok_cand, tok_refs) in enumerate(zip(tok_cands, tok_mrefs)): + if stemmer is not None: + tok_cand = [stemmer.stem(token) for token in tok_cand] + tok_refs = [[stemmer.stem(token) for token in ref] for ref in tok_refs] + + tok_cand = set(tok_cand) + tok_refs = [set(ref) for ref in tok_refs] + + similarities = [] + for tok_ref in tok_refs: + similarity = len(tok_cand.intersection(tok_ref)) / len( + tok_cand.union(tok_ref) + ) + similarities.append(similarity) + + if len(similarities) > 0: + sim = sum(similarities) / len(similarities) + else: + sim = 0.0 + jaccard_scores[i] = sim + + jaccard_score = jaccard_scores.mean() + + if return_all_scores: + corpus_scores = { + "jaccard": jaccard_score, + } + sents_scores = { + "jaccard": jaccard_scores, + } + return corpus_scores, sents_scores + else: + return jaccard_score diff --git a/src/conette/metrics/functional/new_words.py b/src/conette/metrics/functional/new_words.py new file mode 100644 index 000000000..f18d2fd8e --- /dev/null +++ b/src/conette/metrics/functional/new_words.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Callable, Iterable, Union + +import torch + +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def new_words( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + train_vocab: Iterable[str] = (), + tokenizer: Callable[[str], list[str]] = str.split, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + tok_cands = list(map(tokenizer, candidates)) + del candidates, mult_references + + train_vocab = dict.fromkeys(train_vocab) + + dtype = torch.float64 + new_words_lst = [set(tokens).difference(train_vocab) for tokens in tok_cands] + new_words_counts = torch.as_tensor(list(map(len, new_words_lst)), dtype=dtype) + new_words_total = new_words_counts.mean() + + if return_all_scores: + corpus_scores = { + "new_words": new_words_total, + } + sents_scores = { + "new_words": new_words_counts, + } + return corpus_scores, sents_scores + else: + return new_words_total diff --git a/src/conette/metrics/functional/self_bleu.py b/src/conette/metrics/functional/self_bleu.py new file mode 100644 index 000000000..a623b85ef --- /dev/null +++ b/src/conette/metrics/functional/self_bleu.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import copy + +from typing import Callable, Optional, Union + +import torch + +from torch import Tensor + +from aac_metrics.functional.bleu import bleu + + +def self_bleu( + sentences: list[str], + n: int = 4, + max_refs: Optional[int] = None, + generator: Union[None, int, torch.Generator] = None, + tokenizer: Callable[[str], list[str]] = str.split, +) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + if isinstance(generator, int): + generator = torch.Generator().manual_seed(generator) + if max_refs is not None and max_refs >= len(sentences) - 1: + raise ValueError( + f"Invalid argument {max_refs=}. 
(found {max_refs=} >= {len(sentences)-1})" + ) + + self_bleu_scores = [] + for i, sentence in enumerate(sentences): + if max_refs is None: + other_candidates = copy.deepcopy(sentences) + other_candidates.pop(i) + else: + continue_ = True + indexes = [] + while continue_: + indexes = torch.randperm(len(sentences), generator=generator)[ + :max_refs + ].tolist() + continue_ = i in indexes + other_candidates = [sentences[idx] for idx in indexes] + + score = bleu( + [sentence], + [other_candidates], + n=n, + tokenizer=tokenizer, + ) + self_bleu_scores.append(score) + + dtype = torch.float64 + self_bleu_scores = torch.as_tensor(self_bleu_scores, dtype=dtype) + self_bleu_score = self_bleu_scores.mean() + + corpus_scores = { + f"self_bleu_{n}": self_bleu_score, + } + sents_scores = { + f"self_bleu_{n}": self_bleu_scores, + } + return corpus_scores, sents_scores diff --git a/src/conette/metrics/functional/text_stats.py b/src/conette/metrics/functional/text_stats.py new file mode 100644 index 000000000..b55f8e91d --- /dev/null +++ b/src/conette/metrics/functional/text_stats.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from collections import Counter +from typing import Callable, Union + +import torch + +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def text_stats( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + seed: Union[None, int, torch.Generator] = 123, + tokenizer: Callable[[str], list[str]] = str.split, +) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """Compute text statistics about sentences lengths and vocab sizes.""" + + if len(mult_references) <= 0: + raise ValueError( + f"Invalid number of references. (found {len(mult_references)} references)" + ) + + tok_cands = list(map(tokenizer, candidates)) + tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] + del candidates, mult_references + + sent_lens_cands = list(map(len, tok_cands)) + sent_lens_mrefs = [sum(map(len, refs)) / len(refs) for refs in tok_mrefs] + + dtype = torch.float64 + sent_lens_cands = torch.as_tensor(sent_lens_cands, dtype=dtype) + sent_lens_mrefs = torch.as_tensor(sent_lens_mrefs, dtype=dtype) + sent_lens_ratios = sent_lens_cands / sent_lens_mrefs + + global_cands_counter = Counter(token for cand in tok_cands for token in cand) + global_mrefs_counter = Counter( + token for refs in tok_mrefs for ref in refs for token in ref + ) + + total_mrefs_tokens = max(sum(global_mrefs_counter.values()), 1) + vocab_coverage = sum( + global_mrefs_counter[token] / total_mrefs_tokens + for token in global_cands_counter.keys() + ) + + cands_vocab_in_ref = [ + token + for token in global_cands_counter.keys() + if token in global_mrefs_counter.keys() + ] + vocab_in_ref_len = torch.as_tensor(len(cands_vocab_in_ref), dtype=dtype) + vocab_in_ref_ratio = vocab_in_ref_len / len(global_cands_counter) + + vocab_len_cands = torch.as_tensor(len(global_cands_counter), dtype=dtype) + vocab_len_mrefs_full = torch.as_tensor(len(global_mrefs_counter), dtype=dtype) + vocab_len_ratio_full = vocab_len_cands / vocab_len_mrefs_full + vocab_coverage = torch.as_tensor(vocab_coverage, dtype=dtype) + + if isinstance(seed, int): + generator = torch.Generator().manual_seed(seed) + else: + generator = seed + + max_n_refs_per_audio = max(len(refs) for refs in tok_mrefs) + vocab_len_lst = torch.empty((max_n_refs_per_audio,), dtype=dtype) + + for i in range(max_n_refs_per_audio): + indexes = [ + int(torch.randint(0, len(refs), (), 
generator=generator).item()) + for refs in tok_mrefs + ] + popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)] + vocab_len = len(set(token for ref in popped_refs for token in ref)) + vocab_len_lst[i] = vocab_len + + vocab_len_mrefs_avg = vocab_len_lst.mean() + vocab_len_ratio_avg = vocab_len_cands / vocab_len_mrefs_avg + + empty_sents = torch.as_tensor( + [(1 if len(cand) == 0 else 0) for cand in tok_cands], dtype=dtype + ) + empty_sents_rate = empty_sents.mean() + + if return_all_scores: + sents_scores = { + "sent_len.cands": sent_lens_cands, + "sent_len.mrefs": sent_lens_mrefs, + "sent_len.ratio": sent_lens_ratios, + "empty_sents": empty_sents, + } + corpus_scores = { + "sent_len.cands": sent_lens_cands.mean(), + "sent_len.mrefs": sent_lens_mrefs.mean(), + "sent_len.ratio": sent_lens_ratios.mean(), + "vocab_len.cands": vocab_len_cands, + "vocab_len.mrefs_full": vocab_len_mrefs_full, + "vocab_len.ratio_full": vocab_len_ratio_full, + "vocab_len.mrefs_avg": vocab_len_mrefs_avg, + "vocab_len.ratio_avg": vocab_len_ratio_avg, + "vocab_coverage": vocab_coverage, + "vocab_in_ref_len": vocab_in_ref_len, + "vocab_in_ref_ratio": vocab_in_ref_ratio, + "empty_sents": empty_sents_rate, + "sent_len.cands.min": sent_lens_cands.min(), + "sent_len.cands.max": sent_lens_cands.max(), + } + + return corpus_scores, sents_scores + else: + raise ValueError( + f"Cannot compute text_stats() function with {return_all_scores=}." + ) diff --git a/src/conette/metrics/functional/torch_cider_d.py b/src/conette/metrics/functional/torch_cider_d.py new file mode 100644 index 000000000..0ffe21277 --- /dev/null +++ b/src/conette/metrics/functional/torch_cider_d.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math + +from collections import defaultdict, Counter +from typing import Any, Mapping, Union + +import torch + +from torch import Tensor + + +class FrozenHashableTensor(Tensor): + def __init__(self, x: Tensor) -> None: + super().__init__() + self.set_(x.storage()) + self._hash = self.hash() + + def hash(self) -> int: + arange = torch.arange(self.nelement(), dtype=self.dtype, device=self.device) + x = self.flatten() * arange + return int(x.sum().item()) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, FrozenHashableTensor) + and self.shape == other.shape + and bool(torch.eq(self, other).all().item()) + ) + + def __hash__(self) -> int: + return self._hash + + +def torch_cider_d( + candidates: Tensor, + mult_references: Tensor, + return_all_scores: bool = True, + n: int = 4, + bos_id: int = 1, + eos_id: int = 2, + sigma: float = 6.0, + return_tfidf: bool = False, +) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Any]]]: + """ + :param n: set cider to sum over 1 to 4-grams + :param sigma: set the standard deviation parameter for gaussian penalty + """ + cooked_cands, cooked_mrefs = _torch_cider_d_update( + candidates, + mult_references, + n, + bos_id, + eos_id, + [], + [], + ) + return _torch_cider_d_compute( + cooked_cands, + cooked_mrefs, + return_all_scores, + n, + sigma, + return_tfidf, + ) + + +def _torch_cider_d_update( + candidates: Tensor, + mult_references: Tensor, + n: int, + bos_id: int, + eos_id: int, + prev_cooked_cands: list, + prev_cooked_mrefs: list, +) -> tuple[list, list]: + if len(candidates) != len(mult_references): + raise ValueError( + f"Invalid number of candidates and references.
(found {len(candidates)=} != {len(mult_references)=})" + ) + new_cooked_mrefs = [ + [_cook_sentence(ref, n, bos_id, eos_id) for ref in refs] + for refs in mult_references + ] + new_cooked_cands = [_cook_sentence(cand, n, bos_id, eos_id) for cand in candidates] + prev_cooked_cands += new_cooked_cands + prev_cooked_mrefs += new_cooked_mrefs + return prev_cooked_cands, prev_cooked_mrefs + + +def _torch_cider_d_compute( + cooked_cands: list[Counter[FrozenHashableTensor]], + cooked_mrefs: list[list[Counter[FrozenHashableTensor]]], + return_all_scores: bool, + n: int, + sigma: float, + return_tfidf: bool, +) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Any]]]: + if len(cooked_cands) <= 1: + raise ValueError( + f"CIDEr metric does not support less than 2 candidates with 2 references. (found {len(cooked_cands)} candidates, but expected > 1)" + ) + # compute idf + document_frequency = _compute_doc_freq(cooked_mrefs) + # compute log reference length + log_ref_len = math.log(float(len(cooked_mrefs))) + # sanity check: assert to check document frequency + assert len(cooked_cands) >= max(document_frequency.values()) + # compute cider score + cider_d_scores, tfidf_lst = _compute_cider( + cooked_cands, + cooked_mrefs, + document_frequency, + log_ref_len, + n, + sigma, + ) + cider_d_score = cider_d_scores.mean() + + if return_all_scores: + cider_d_global_outs = { + "cider_d": cider_d_score, + } + cider_d_local_outs = { + "cider_d": cider_d_scores, + } + if return_tfidf: + cider_d_local_outs["tfidf_lst"] = tfidf_lst # type: ignore + + return cider_d_global_outs, cider_d_local_outs + else: + return cider_d_score + + +def _cook_sentence( + sentence: Tensor, + n: int, + bos_id: int, + eos_id: int, +) -> Counter[FrozenHashableTensor]: + if sentence[0] == bos_id: + sentence = sentence[1:] + + if eos_id in sentence: + eos_pos = sentence.eq(eos_id).int().argmax() + sentence = sentence[:eos_pos] + + sentence = FrozenHashableTensor(sentence) + + counter = Counter() + for k in range(1, n + 1): + for i in range(len(sentence) - k + 1): + ngram = sentence[i : i + k] + counter[ngram] += 1 + + return counter + + +def _compute_doc_freq( + cooked_mrefs: list[list[Counter[FrozenHashableTensor]]], +) -> Counter[FrozenHashableTensor]: + """ + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + """ + document_frequency = Counter() + for cooked_refs in cooked_mrefs: + # refs, k ref captions of one image + for ngram in set( + ngram for cooked_ref in cooked_refs for ngram in cooked_ref.keys() + ): + document_frequency[ngram] += 1 + # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + return document_frequency + + +def _counter_to_vec( + counters: Mapping[FrozenHashableTensor, int], + log_ref_len: float, + n: int, + document_frequency: Counter[FrozenHashableTensor], +) -> tuple[list[defaultdict], Tensor, int]: + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. 
+ :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(n)] + length = 0 + norm = torch.zeros((n,), dtype=torch.float64) + + for ngram, term_freq in counters.items(): + # give word count 1 if it doesn't appear in reference corpus + log_df = math.log(max(1.0, document_frequency[ngram])) + + # ngram index + ni = len(ngram) - 1 + + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[ni][ngram] = float(term_freq) * (log_ref_len - log_df) + + # compute norm for the vector. the norm will be used for computing similarity + norm[ni] += pow(vec[ni][ngram], 2) + + if ni == 1: + length += term_freq + + norm = torch.sqrt(norm) + return vec, norm, length + + +def _similarity( + cand_vec: list[defaultdict], + ref_vec: list[defaultdict], + cand_norm: Tensor, + ref_norm: Tensor, + cand_len: int, + ref_len: int, + n: int, + sigma: float, +) -> Tensor: + """ + Compute the cosine similarity of two vectors. + :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + """ + # measure consine similarity + val = torch.zeros((n,), dtype=torch.float64) + + for ni in range(n): + # ngram + for ngram, count in cand_vec[ni].items(): + # vrama91 : added clipping + val[ni] += min(count, ref_vec[ni][ngram]) * ref_vec[ni][ngram] + + norms = cand_norm * ref_norm + norms[norms == 0.0] = 1.0 + val = val / norms + + # vrama91: added a length based gaussian penalty + delta = float(cand_len - ref_len) + val = val * math.e ** (-(delta**2) / (2 * sigma**2)) + + return val + + +def _compute_cider( + cooked_cands: list[Counter[FrozenHashableTensor]], + cooked_mrefs: list[list[Counter[FrozenHashableTensor]]], + document_frequency: Counter, + log_ref_len: float, + n: int, + sigma: float, + scale: float = 10.0, +) -> tuple[Tensor, list[tuple]]: + scores = torch.empty((len(cooked_cands),), dtype=torch.float64) + tfidf_lst = [] + + for i, (cooked_cand, cooked_refs) in enumerate(zip(cooked_cands, cooked_mrefs)): + # compute vector for test captions + vec, norm, length = _counter_to_vec( + cooked_cand, log_ref_len, n, document_frequency + ) + # compute vector for ref captions + ngrams_scores = torch.zeros((n,), dtype=torch.float64) + vec_refs = [] + for ref in cooked_refs: + vec_ref, norm_ref, length_ref = _counter_to_vec( + ref, log_ref_len, n, document_frequency + ) + vec_refs.append(vec_ref) + ngrams_scores += _similarity( + vec, vec_ref, norm, norm_ref, length, length_ref, n, sigma + ) + # change by vrama91 - mean of ngram scores, instead of sum + # divide by number of mult_references + agg_ngrams_scores = ngrams_scores.mean() / len(cooked_refs) + # multiply score by 10 + agg_ngrams_scores *= scale + # append score of an image to the score list + scores[i] = agg_ngrams_scores + tfidf_lst.append((vec, vec_refs)) + + return scores, tfidf_lst diff --git a/src/conette/metrics/functional/wmd.py b/src/conette/metrics/functional/wmd.py new file mode 100644 index 000000000..c84440d80 --- /dev/null +++ b/src/conette/metrics/functional/wmd.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Callable, 
Union + +import torch + +from gensim.downloader import load +from gensim.models.keyedvectors import KeyedVectors +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def wmdistance( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + tokenizer: Callable[[str], list[str]] = str.split, + model: Union[str, KeyedVectors] = "word2vec-google-news-300", # type: ignore + verbose: int = 0, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + if isinstance(model, str): + model: KeyedVectors = load(model, return_path=False) # type: ignore + if verbose >= 2: + path: str = load(model, return_path=True) # type: ignore + pylog.debug(f"Load gensim model from {path=}.") + + dtype = torch.float64 + tok_cands = list(map(tokenizer, candidates)) + tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] + + distances = torch.zeros((len(tok_cands),), dtype=torch.float64) + + for i, (tok_cand, tok_refs) in enumerate(zip(tok_cands, tok_mrefs)): + distances_i = [model.wmdistance(tok_cand, tok_ref) for tok_ref in tok_refs] + distances_i = torch.as_tensor(distances_i, dtype=dtype) + distances[i] = distances_i.mean() + + distance = distances.mean() + + if return_all_scores: + corpus_scores = { + "wmd": distance, + } + sents_scores = { + "wmd": distances, + } + return corpus_scores, sents_scores + else: + return distance diff --git a/src/conette/metrics/retrieval.py b/src/conette/metrics/retrieval.py new file mode 100644 index 000000000..cc95a5d66 --- /dev/null +++ b/src/conette/metrics/retrieval.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Callable, Union + +import numpy as np +import torch +import tqdm + +from torch import Tensor + +from conette.nn.functional.misc import can_be_stacked + + +def retrieval_metrics( + scores: Tensor, + is_matching: Union[Callable[[int, int], bool], np.ndarray, Tensor], + return_all_scores: bool = True, + return_retrieved_indexes: bool = False, + limit_relevant_with_k: bool = False, + consider_only_best: bool = True, + verbose: int = 0, +) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: + row_names = list(map(str, range(scores.shape[0]))) + col_names = list(map(str, range(scores.shape[1]))) + + if isinstance(is_matching, Callable): + matching_matrix = _matching_fn_to_matrix(is_matching, scores.shape) + else: + matching_matrix = is_matching + assert ( + matching_matrix.shape == scores.shape + ), f"{matching_matrix.shape=} != {scores.shape=}" + + qid2items = _matrix_to_qid2items(scores, row_names, col_names, matching_matrix) + qid_mAP10s, qid_R1s, qid_R5s, qid_R10s, retrieved_indexes = _measure( + qid2items, limit_relevant_with_k, consider_only_best, verbose + ) + + dtype = torch.float64 + qid_mAP10s = torch.as_tensor(qid_mAP10s, dtype=dtype) + qid_R1s = torch.as_tensor(qid_R1s, dtype=dtype) + qid_R5s = torch.as_tensor(qid_R5s, dtype=dtype) + qid_R10s = torch.as_tensor(qid_R10s, dtype=dtype) + + retrieval_outs_sents: dict[str, Any] = { + "mAP10": qid_mAP10s, + "R1": qid_R1s, + "R5": qid_R5s, + "R10": qid_R10s, + } + retrieval_outs_corpus = { + name: c_scores.mean() for name, c_scores in retrieval_outs_sents.items() + } + + for n in (1, 5): + uniq_top_n_count = torch.as_tensor( + len(set(scores.argsort(dim=1, descending=True)[:, :n].flatten().tolist())), + dtype=dtype, + ) + uniq_top_n_max = torch.as_tensor(scores.shape, dtype=dtype).min() + uniq_top_n_rate = uniq_top_n_count / uniq_top_n_max + 
retrieval_outs_corpus[f"uniq_top{n}_count"] = uniq_top_n_count + retrieval_outs_corpus[f"uniq_top{n}_max"] = uniq_top_n_max + retrieval_outs_corpus[f"uniq_top{n}_rate"] = uniq_top_n_rate + + sorted_indexes = scores.argsort(dim=1, descending=True) + mean_ranks = torch.empty((len(sorted_indexes),), dtype=dtype) + ranks = [] + + for i, indexes in enumerate(sorted_indexes): + ranks_i = torch.where(matching_matrix[i][indexes])[0] + # ranks of shape (n,) + mean_rank_i = (ranks_i / matching_matrix.shape[1]).mean() + mean_ranks[i] = mean_rank_i + + ranks.append(ranks_i) + + mean_rank = mean_ranks.mean() + med_rank = mean_ranks.median() + + retrieval_outs_sents["mean_rank"] = mean_ranks + retrieval_outs_corpus["mean_rank"] = mean_rank + retrieval_outs_corpus["med_rank"] = med_rank + + if can_be_stacked(ranks): + ranks = torch.stack(ranks) + retrieval_outs_sents["rank"] = ranks + + if return_retrieved_indexes: + retrieval_outs_sents["retrieved_indexes"] = torch.as_tensor( + retrieved_indexes, dtype=torch.long + ) + + if return_all_scores: + return retrieval_outs_corpus, retrieval_outs_sents + else: + return retrieval_outs_sents["mAP10"] + + +def _matching_fn_to_matrix( + is_matching: Callable[[int, int], bool], size: tuple[int, int] +) -> Tensor: + matching_matrix = torch.full(size, False, dtype=torch.bool) + for i in range(size[0]): + for j in range(size[1]): + matching_matrix[i, j] = is_matching(i, j) + return matching_matrix + + +def _matrix_to_qid2items( + scores: Tensor, + row_names: list[str], + col_names: list[str], + is_matching: Union[np.ndarray, Tensor], +) -> dict[str, list[tuple[str, float, bool]]]: + assert tuple(scores.shape) == (len(row_names), len(col_names)) + qid2items = {} + for i, name_i in enumerate(row_names): + qid2items[name_i] = [ + (name_j, scores[i, j].item(), is_matching[i, j]) + for j, name_j in enumerate(col_names) + ] + return qid2items + + +def _measure( + qid2items: dict[Any, list[tuple[Any, float, bool]]], + limit_relevant_with_k: bool, + consider_only_best: bool, + verbose: int, +) -> tuple[list, list, list, list, list]: + """Retrieval metrics over sample queries + + i.e., recall@{1, 5, 10}, mAP@10. + BASED on https://github.com/xieh97/dcase2023-audio-retrieval/blob/master/postprocessing/xmodal_retrieval.py#L32 + """ + mAP_top = 10 + + qid_R1s = [] + qid_R5s = [] + qid_R10s = [] + qid_mAP10s = [] + retrieved_indexes = [] + + for items in tqdm.tqdm(qid2items.values(), disable=verbose < 2): + scores = np.array([i[1] for i in items]) + targets = np.array([i[2] for i in items]) + + # assert ( + # targets.sum() == 1 + # ) # DEBUG: for text-to-audio, we expect only 1 audio per query + + desc_indices = np.argsort(scores, axis=-1)[::-1] + targets = np.take_along_axis(arr=targets, indices=desc_indices, axis=-1) + + retrieved_indexes.append(desc_indices.tolist()) + + # Recall at cutoff K + targets_sum = np.sum(targets, dtype=float) + + top1_sum = np.sum(targets[:1], dtype=float) + top5_sum = np.sum(targets[:5], dtype=float) + top10_sum = np.sum(targets[:10], dtype=float) + + if limit_relevant_with_k and consider_only_best: + raise ValueError( + f"Incompatible arguments values {limit_relevant_with_k=} and {consider_only_best=}. 
(one or both must be False)" + ) + + elif limit_relevant_with_k: + recall_at_1 = top1_sum / min(targets_sum, 1) + recall_at_5 = top5_sum / min(targets_sum, 5) + recall_at_10 = top10_sum / min(targets_sum, 10) + + elif consider_only_best: + recall_at_1 = min(top1_sum, 1) + recall_at_5 = min(top5_sum, 1) + recall_at_10 = min(top10_sum, 1) + + else: # default behaviour + recall_at_1 = top1_sum / targets_sum + recall_at_5 = top5_sum / targets_sum + recall_at_10 = top10_sum / targets_sum + + qid_R1s.append(recall_at_1) + qid_R5s.append(recall_at_5) + qid_R10s.append(recall_at_10) + + # Mean average precision + positions = np.arange(1, mAP_top + 1, dtype=float)[targets[:mAP_top] > 0] + if len(positions) > 0: + precisions = np.divide( + np.arange(1, len(positions) + 1, dtype=float), positions + ) + avg_precision = np.sum(precisions, dtype=float) / targets_sum + qid_mAP10s.append(avg_precision) + else: + qid_mAP10s.append(0.0) + + return qid_mAP10s, qid_R1s, qid_R5s, qid_R10s, retrieved_indexes diff --git a/src/conette/nn/cnext_ckpt_utils.py b/src/conette/nn/cnext_ckpt_utils.py new file mode 100644 index 000000000..30a3d4dfa --- /dev/null +++ b/src/conette/nn/cnext_ckpt_utils.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os.path as osp + +from typing import Union + +import torch + +from torch import Tensor + +from conette.transforms.audioset_labels import load_audioset_mapping + + +pylog = logging.getLogger(__name__) + + +# Zenodo link : https://zenodo.org/record/8020843/ +# Hash type : md5 +CNEXT_PRETRAINED_URLS = { + "cnext_nobl": { + "model": "ConvNeXt", + "url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1", + "hash": "e069ecd1c7b880268331119521c549f2", + "fname": "convnext_tiny_471mAP.pth", + }, + "cnext_bl": { + "model": "ConvNeXt", + "url": "https://zenodo.org/record/8020843/files/convnext_tiny_465mAP_BL_AC_70kit.pth?download=1", + "hash": "0688ae503f5893be0b6b71cb92f8b428", + "fname": "convnext_tiny_465mAP_BL_AC_70kit.pth", + }, +} + + +def cnext_get_ckpt_dir_path() -> str: + """Return the path to the directory containing CNEXT checkpoints files.""" + return osp.join(torch.hub.get_dir(), "checkpoints") + + +def cnext_get_ckpt_path(model_name: str) -> str: + """Return the path to the CNEXT checkpoint file.""" + if model_name not in CNEXT_PRETRAINED_URLS: + raise ValueError( + f"Invalid argument {model_name=}. (expected one of {tuple(CNEXT_PRETRAINED_URLS.keys())})" + ) + + fname = CNEXT_PRETRAINED_URLS[model_name]["fname"] + fpath = osp.join(cnext_get_ckpt_dir_path(), fname) + return fpath + + +def cnext_load_state_dict( + model_name_or_path: str, + device: Union[str, torch.device, None] = None, + offline: bool = False, + verbose: int = 0, +) -> dict[str, Tensor]: + """Load CNEXT state_dict weights. + + :param model_name_or_path: Model name (case sensitive) or path to CNEXT checkpoint file. + :param device: Device of checkpoint weights. defaults to None. + :param offline: If False, the checkpoint from a model name will be automatically downloaded. + defaults to False. + :param verbose: Verbose level. defaults to 0. + :returns: State dict of model weights. 
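+
+    Example (illustrative sketch; ``encoder`` is an assumed compatible ConvNeXt module, not defined here)
+    ----------
+    >>> state_dict = cnext_load_state_dict("cnext_bl", device="cpu")
+    >>> encoder.load_state_dict(state_dict, strict=False)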
+ """ + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if osp.isfile(model_name_or_path): + model_path = model_name_or_path + else: + try: + model_path = cnext_get_ckpt_path(model_name_or_path) + except ValueError: + raise ValueError( + f"Invalid argument {model_name_or_path=}. (expected a path to a checkpoint file or a model name in {tuple(CNEXT_PRETRAINED_URLS.keys())})" + ) + + if not osp.isfile(model_path): + if offline: + raise FileNotFoundError( + f"Cannot find checkpoint model file in '{model_path}' with mode {offline=}." + ) + else: + cnext_download_ckpt(model_name_or_path, verbose) + + del model_name_or_path + + data = torch.load(model_path, map_location=device) + state_dict = data["model"] + + if verbose >= 1: + test_map = data.get("test_mAP", "unknown") + pylog.info( + f"Loading encoder weights from '{model_path}'... (with test_mAP={test_map})" + ) + + return state_dict + + +def cnext_download_ckpt(model_name: str, verbose: int = 0) -> None: + """Download CNEXT checkpoint file.""" + fpath = cnext_get_ckpt_path(model_name) + url = CNEXT_PRETRAINED_URLS[model_name]["url"] + torch.hub.download_url_to_file(url, fpath, progress=verbose >= 1) + + +def probs_to_binarized( + probs: Tensor, + threshold: Union[float, Tensor], +) -> Tensor: + if probs.ndim != 2: + raise ValueError( + f"Invalid argument probs. (expected a batch of probabilities of shape (N, n_classes))." + ) + nb_classes = probs.shape[1] + + if isinstance(threshold, Tensor) and threshold.ndim == 1: + threshold = threshold.item() + + if isinstance(threshold, (float, int)): + threshold = torch.full( + (nb_classes,), threshold, dtype=torch.float, device=probs.device + ) + else: + if threshold.shape[1] != nb_classes: + raise ValueError("Invalid argument threshold.") + threshold = threshold.to(device=probs.device) + + binarized = probs >= threshold + return binarized + + +def binarized_to_indices( + binarized: Tensor, +) -> list[list[int]]: + preds = [] + for binarized_i in binarized: + preds_i = torch.where(binarized_i)[0].tolist() + preds.append(preds_i) + return preds + + +def probs_to_indices( + probs: Tensor, + threshold: Union[float, Tensor], +) -> list[list[int]]: + binarized = probs_to_binarized(probs, threshold) + preds = binarized_to_indices(binarized) + return preds + + +def probs_to_labels( + probs: Tensor, + threshold: Union[float, Tensor], + audioset_indices_fpath: str, +) -> list[list[str]]: + indices = probs_to_indices(probs, threshold) + labels = indices_to_labels(indices, audioset_indices_fpath) + return labels + + +def indices_to_labels( + indices: Union[list[list[int]], list[Tensor]], + audioset_indices_fpath: str, +) -> list[list[str]]: + name_to_idx = load_audioset_mapping() + idx_to_name = {idx: name for name, idx in name_to_idx.items()} + + labels = [] + for indices_i in indices: + names = [idx_to_name[idx] for idx in indices_i] # type: ignore + labels.append(names) + return labels diff --git a/src/conette/nn/decoding/greedy.py b/src/conette/nn/decoding/greedy.py new file mode 100644 index 000000000..fae45619d --- /dev/null +++ b/src/conette/nn/decoding/greedy.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math + +from typing import Any, Optional + +import torch + +from torch import Tensor + +from conette.nn.decoding.common import AACDecoder +from conette.nn.functional.label import ints_to_multihots +from conette.nn.functional.mask import generate_square_subsequent_mask + + 
+@torch.no_grad() +def greedy_search( + decoder: AACDecoder, + pad_id: int, + bos_id: int, + eos_id: int, + vocab_size: int, + frame_embs: Tensor, + frame_embs_pad_mask: Tensor, + min_pred_size: int = 0, + max_pred_size: int = 20, + forbid_rep_mask: Optional[Tensor] = None, +) -> Tensor: + """Greedy search for Transformer decoder. + + :param decoder: The decoder part of the model. + :param pad_id: Padding token id. + :param bos_id: Begin-of-Sentence token id. + :param eos_id: End-of-Sentence token id. + :param vocab_size: Vocabulary size of the model. + :param frame_embs: (bsize, frame_emb_size, n_frames) + :param frame_embs_pad_mask: (bsize, audio_seq_size) + :param min_pred_size: Minimal number of tokens in the output sentences. defaults to 0. + :param max_pred_size: Maximal number of tokens in the output sentences. defaults to 20. + :param forbid_rep_mask: TODO + :returns: logits of shape (bsize, vocab_size, max_pred_size or less) + """ + assert min_pred_size >= 0 + + bsize = frame_embs.shape[0] + device = frame_embs.device + bkwds: dict[str, Any] = dict(dtype=torch.bool, device=device) + fkwds: dict[str, Any] = dict(dtype=frame_embs.dtype, device=device) + ikwds: dict[str, Any] = dict(dtype=torch.long, device=device) + + # (bsize, emb_size, n_frames) -> (n_frames, bsize, emb_size) + frame_embs = frame_embs.permute(2, 0, 1) + + batch_idxs = torch.arange(bsize, **ikwds) + + preds = torch.full( + (bsize, max_pred_size + 1), + pad_id, + **ikwds, + ) + preds[:, 0] = bos_id + + global_logits_out = torch.full( + (bsize, vocab_size, max_pred_size), + -math.inf, + **fkwds, + ) + global_logits_out[:, pad_id, :] = 0 + + caps_in_sq_mask = generate_square_subsequent_mask(max_pred_size, device) + if forbid_rep_mask is None: + forbid_rep_mask = torch.zeros((vocab_size,), **bkwds) + use_forbid_rep = forbid_rep_mask.eq(True).any() + + # unfinished sentence mask + # unfinished_mask = torch.full((bsize,), True, device=device, dtype=torch.bool) + pred_size = max_pred_size + + for i in range(max_pred_size): + preds_in_i = preds[:, : i + 1].transpose(0, 1) + caps_in_sq_mask_i = caps_in_sq_mask[: i + 1, : i + 1] + + full_logits_i = decoder( + frame_embs.contiguous(), + frame_embs_pad_mask.contiguous(), + preds_in_i.contiguous(), + None, + caps_in_sq_mask_i.contiguous(), + ) + # full_logits_i : (i+1, cur_size, vocab_size) + logits_i = full_logits_i[-1] + del full_logits_i + # logits_i : (cur_size, vocab_size) + + if i < min_pred_size: + logits_i[:, eos_id] = -math.inf + + if use_forbid_rep: + prev_preds = preds[:, : i + 1] + prev_preds_ohot = ints_to_multihots(prev_preds, vocab_size, **bkwds) + prev_preds_ohot = prev_preds_ohot.logical_and_( + forbid_rep_mask.unsqueeze(dim=0) + ) + logits_i[prev_preds_ohot] = -math.inf + + next_toks_i = logits_i.argmax(dim=-1) + # next_toks_i shape: (cur_size,) + preds[:, i + 1] = next_toks_i + + if i < max_pred_size - 1: + is_unfinished_i = next_toks_i != eos_id + else: + is_unfinished_i = torch.full((logits_i.shape[0],), False, **bkwds) + + global_logits_out[batch_idxs, :, i] = logits_i + + preds = preds[is_unfinished_i] + batch_idxs = batch_idxs[is_unfinished_i] + frame_embs = frame_embs[:, is_unfinished_i] + frame_embs_pad_mask = frame_embs_pad_mask[is_unfinished_i] + + if preds.nelement() <= 0: + pred_size = i + 1 + break + + if pred_size < max_pred_size: + global_logits_out = global_logits_out[:, :, :pred_size].contiguous() + + # logits shape: (bsize, vocab_size, pred_size) + return global_logits_out diff --git a/src/conette/nn/encoders/cnn10.py 
b/src/conette/nn/encoders/cnn10.py new file mode 100644 index 000000000..bb0e78b60 --- /dev/null +++ b/src/conette/nn/encoders/cnn10.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# BASED ON Cnn10 class from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py#L484 + +import logging + +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import yaml + +from torch import Tensor +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from conette.nn.pann_utils.models import init_bn, init_layer, do_mixup +from conette.nn.pann_utils.models import ConvBlock +from conette.nn.pann_utils.ckpt import pann_load_state_dict + + +pylog = logging.getLogger(__name__) + + +class Cnn10(nn.Module): + AUDIOSET_NUM_CLASSES = 527 + CONV_FEATURES_EMB_SIZE = 512 + + def __init__( + self, + # Spectrogram extractor args + sr: int = 32000, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + # Other args + return_clip_outputs: bool = True, + return_frame_outputs: bool = False, + classes_num: int = 527, + clip_emb_size: int = 512, + frame_emb_size: int = 512, + freeze_weight: str = "none", + lens_rounding_mode: str = "trunc", + pretrained: bool = False, + use_specaug: bool = False, + waveform_input: bool = True, + convblock_dropout: float = 0.2, + freeze_spectro_extractor: bool = True, + freeze_logmel_extractor: bool = True, + use_fc2_layer: bool = False, + ) -> None: + """ + Compute frame-embeddings of shape (bsize, embed_len, n_frames) from audio. + + :param sr: defaults to 32000. + :param window_size: defaults to 1024. + :param hop_size: defaults to 320. + :param mel_bins: defaults to 64. + :param fmin: defaults to 50. + :param fmax: defaults to 14000. + :param add_clip_linear: TODO + :param add_frame_linear: TODO + :param classes_num: TODO + :param clip_emb_size: TODO + :param frame_emb_size: TODO + :param freeze_weight: TODO + :param lens_rounding_mode: TODO + :param pretrained: If True, use pretrained weights from PANN. defaults to True. + :param use_spec_augment: TODO + :param waveform_input: TODO + :param convblock_dropout: Dropout used after ConvBlocks. defaults to 0.2. + :param freeze_spectro_extractor: If true, freezes spectrogram extractor weights. defaults to True. + :param freeze_logmel_extractor: If true, freezes logmel extrator weights. defaults to True. + """ + if return_frame_outputs: + pylog.warning( + f"Deprecated argument value {return_frame_outputs=}. Please use projection in a separate module." + ) + + if not pretrained and freeze_weight != "none": + raise ValueError( + f"Cannot freeze weights without using pre-trained weights. (found {freeze_weight=} && {pretrained=})" + ) + if ( + pretrained + and return_clip_outputs + and classes_num != self.AUDIOSET_NUM_CLASSES + ): + pylog.warning( + f"Found argument {classes_num=} != {self.AUDIOSET_NUM_CLASSES} and {pretrained=}, so the layer 'fc_audioset' will not use pretrained weights." + ) + if pretrained and return_clip_outputs and clip_emb_size != 512: + raise ValueError( + f"Invalid argument {clip_emb_size=} with {pretrained=} and {return_clip_outputs=}." + ) + if not return_frame_outputs and frame_emb_size != 512: + raise ValueError( + f"Cannot remove Linear 'fc2' in CNN10 with {return_frame_outputs=} and {frame_emb_size=} != 512. Please use add_frame_linear=True or frame_emb_size=512." 
+ ) + if lens_rounding_mode not in ("trunc", "ceil"): + raise ValueError( + f"Invalid argument {lens_rounding_mode=} (expected 'trunc' or 'ceil')" + ) + + super().__init__() + # Params + self.return_clip_outputs = return_clip_outputs + self.return_frame_outputs = return_frame_outputs + self.classes_num = classes_num + self.clip_emb_size = clip_emb_size + self.frame_emb_size = frame_emb_size + self.lens_rounding_mode = lens_rounding_mode + self.pretrained = pretrained + self.use_specaug = use_specaug + self.waveform_input = waveform_input + self.convblock_dropout = convblock_dropout + self.use_fc2_layer = use_fc2_layer + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + if self.waveform_input: + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_spectro_extractor, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_logmel_extractor, + ) + else: + self.spectrogram_extractor = nn.Identity() + self.logmel_extractor = nn.Identity() + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + + # For tag outputs + self.fc1 = ( + nn.Linear(Cnn10.CONV_FEATURES_EMB_SIZE, clip_emb_size, bias=True) + if return_clip_outputs + else nn.Identity() + ) + self.fc_audioset = ( + nn.Linear(clip_emb_size, classes_num, bias=True) + if return_clip_outputs + else nn.Identity() + ) + # For frame outputs (not pretrained !) 
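+        # fc2 maps the 512-dim conv features to frame_emb_size for each frame; it is a
+        # real Linear only when return_frame_outputs and use_fc2_layer are both True,
+        # otherwise it stays an nn.Identity.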
+ self.fc2 = ( + nn.Linear(Cnn10.CONV_FEATURES_EMB_SIZE, frame_emb_size, bias=True) + if return_frame_outputs and use_fc2_layer + else nn.Identity() + ) + + # Initialize weights + self.init_weight() + if pretrained: + exclude_spectro_params = window_size != 1024 or hop_size != 320 + self.load_pretrained_weights( + strict=False, exclude_spectro_params=exclude_spectro_params + ) + self.freeze_weight(freeze_weight) + + def init_weight(self) -> None: + init_bn(self.bn0) + for layer in (self.fc1, self.fc_audioset, self.fc2): + if hasattr(layer, "weight"): + init_layer(layer) + + def load_pretrained_weights( + self, + strict: bool = False, + exclude_spectro_params: bool = False, + ) -> None: + device = self.bn0.weight.device + state_dict = pann_load_state_dict("Cnn10", device, offline=False) + + if exclude_spectro_params: + state_dict = { + key: weight + for key, weight in state_dict.items() + if all( + not key.startswith(prefix) + for prefix in ("spectrogram_extractor", "logmel_extractor") + ) + } + if self.pretrained and self.classes_num != self.AUDIOSET_NUM_CLASSES: + state_dict = { + key: weight + for key, weight in state_dict.items() + if all(not key.startswith(prefix) for prefix in ("fc_audioset",)) + } + + self.load_state_dict(state_dict, strict=strict) # type: ignore + + def freeze_weight(self, freeze_weight: str) -> None: + if freeze_weight == "none": + return None + + if freeze_weight == "all": + excluded_lst = ["fc2"] + elif freeze_weight == "first1": + excluded_lst = ["fc2", "conv_block1"] + elif freeze_weight == "last1": + excluded_lst = ["fc2", "conv_block4"] + else: + raise RuntimeError(f"Unknown freeze encoder mode {freeze_weight=}.") + + if self.pretrained and self.classes_num != self.AUDIOSET_NUM_CLASSES: + excluded_lst.append("fc_audioset") + + pylog.debug(f"Freezing layers:\n{yaml.dump(excluded_lst, sort_keys=False)}.") + + for name, param in self.named_parameters(): + if all(not name.startswith(excluded) for excluded in excluded_lst): + param.requires_grad = False + self.eval() + + def _check_forward_input( + self, + x: Tensor, + x_shape: Optional[Tensor], + **kwargs, + ) -> None: + if self.waveform_input: + if not (x.ndim == 2 or (x.ndim == 3 and x.shape[1] == 1)): + raise ValueError( + f"Invalid input shape {x.shape=}. Expected (bsize, time_steps) or (bsize, 1, time_steps) tensor." + ) + else: + if not (x.ndim == 3 or (x.ndim == 4 and x.shape[1] == 1)): + raise ValueError( + f"Invalid input shape {x.shape=}. Expected (bsize, time_steps, freq_bins) or (bsize, 1, time_steps, freq_bins) tensor." + ) + + if x_shape is not None: + if x_shape.ndim != 2: + raise ValueError( + f"Invalid number of dimensions for x_shape argument. (expected 2 dims with shape (bsize, x.ndim-1) tensor but found {x_shape.ndim=})" + ) + if x.shape[0] != x_shape.shape[0]: + raise ValueError( + f"Invalid batch dim 0 for arguments x and x_shape. (found {x.shape[0]=} != {x_shape.shape[0]=})" + ) + if x_shape.shape[1] != x.ndim - 1: + raise ValueError( + f"Invalid x_shape dim 1 {x_shape.shape[1]=}. (expected {x.ndim-1=})" + ) + for x_shape_i in x_shape: + for j, x_shape_ij in enumerate(x_shape_i): + if x_shape_ij > x.shape[j + 1]: + raise ValueError( + f"Found a shape greater than the dimension of the input. (found {x_shape_ij} > {x.shape[j+1]})" + ) + + def forward( + self, + x: Tensor, + x_shape: Optional[Tensor], + mixup_params: Optional[dict] = None, + mixup_lambda: Optional[Tensor] = None, + ) -> dict[str, Any]: + """ + :param x: Batch of audios tensors. 
+ Waveforms shapes (if self.waveform_input=True): + (bsize, 1, time_steps) or (bsize, time_steps) tensor + Spectrograms shapes (if self.waveform_input=False): + (bsize, 1, time_steps, freq_bins) or (bsize, time_steps, freq_bins) tensor + :param x_shape: Shape of non-padded audio of x. Has shape (bsize, x.ndim-1) + :param mixup_params: Dictionary of MixUp parameters. + 'lambda1': coefficient of the first tensor x + 'lambda2': coefficient of the second tensor x2 + 'indexes': tensor of indexes for shuffle x to x2, shape is (bsize,) + :returns: A dictionary with embeddings and logits. + 'frame_embs': (bsize, embed_size, time_steps_reduced) + 'frame_embs_lens': (bsize, embed_size) + "clip_embs": (bsize, 512), + "clip_logits": (bsize, classes_num=527), + """ + self._check_forward_input(x, x_shape) + + if self.waveform_input: + # Convert and format to spectrogram and get x_lens + if x.ndim == 3: + x = x.squeeze_(dim=1) + # x: (bsize, time_steps) + source_len = x.shape[-1] + x_lens = x_shape.squeeze(dim=1) if x_shape is not None else None + + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + else: + # Format spectrogram and get x_lens + if x.ndim == 3: + x = x.unsqueeze_(dim=1) + time_step_dim = 1 + else: # x.ndim == 4 + time_step_dim = 2 + + # x : (bsize, 1, time_steps, freq_bins) + source_len = x.shape[-2] + + # x_shape : (bsize, N) + x_lens = x_shape[:, time_step_dim - 1] if x_shape is not None else None + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training and self.use_specaug: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + # Mixup on spectrogram + if self.training and mixup_params is not None: + # x = do_mixup(x, mixup_lambda) + lambda1 = mixup_params["lambda1"] + lambda2 = mixup_params["lambda2"] + indexes = mixup_params["indexes"] + x = lambda1 * x + lambda2 * x[indexes] + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=self.convblock_dropout, training=self.training) + + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=self.convblock_dropout, training=self.training) + + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=self.convblock_dropout, training=self.training) + + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=self.convblock_dropout, training=self.training) + + x = torch.mean(x, dim=3) + conv_features = x + # conv_features : (bsize, n_filters=512, time_steps, mel_bins) + + outs = {} + + if self.return_clip_outputs: + x = conv_features + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + del x1, x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + clip_embs = F.dropout(x, p=0.5, training=self.training) + clip_logits = self.fc_audioset(x) + del x + outs |= { + "clip_embs": clip_embs, + "clip_logits": clip_logits, + } + + x = conv_features + x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + x = x1 + x2 + + if self.return_frame_outputs and self.use_fc2_layer: + x = F.dropout(x, p=0.5, training=self.training) + x = x.transpose(1, 2) + x = self.fc2(x) + x = F.relu_(x) + x = x.transpose(1, 2) + x = F.dropout(x, p=0.5, training=self.training) + + # x : (bsize, embed_size, time_steps_reduced) + if x_lens is not None: + if 
self.lens_rounding_mode == "trunc": + reduction_factor = source_len // x.shape[-1] + x_lens = x_lens.div(reduction_factor, rounding_mode="trunc") + elif self.lens_rounding_mode == "ceil": + reduction_factor = source_len / x.shape[-1] + x_lens = x_lens.float().div(reduction_factor).ceil() + elif self.lens_rounding_mode == "round": + reduction_factor = source_len / x.shape[-1] + x_lens = x_lens.float().div(reduction_factor).round() + else: + raise ValueError(f"Invalid parameter {self.lens_rounding_mode=}.") + + outs |= { + "frame_embs": x, + "frame_embs_lens": x_lens, + } + return outs diff --git a/src/conette/nn/encoders/cnn14.py b/src/conette/nn/encoders/cnn14.py new file mode 100644 index 000000000..13cb8b612 --- /dev/null +++ b/src/conette/nn/encoders/cnn14.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Any, Mapping, Optional + +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from torch.nn.modules.module import _IncompatibleKeys +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from conette.nn.pann_utils.models import ( + do_mixup, + init_bn, + init_layer, + ConvBlock, +) + + +pylog = logging.getLogger(__name__) + + +class Cnn14(nn.Module): + def __init__( + self, + sample_rate: int = 32000, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + classes_num: int = 527, + waveform_input: bool = True, + use_specaug: bool = True, + return_clip_outputs: bool = True, + return_frame_outputs: bool = False, + ) -> None: + super(Cnn14, self).__init__() + self.waveform_input = waveform_input + self.use_spec_aug = use_specaug + self.return_clip_output = return_clip_outputs + self.return_frame_output = return_frame_outputs + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + if self.waveform_input: + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + else: + self.spectrogram_extractor = nn.Identity() + self.logmel_extractor = nn.Identity() + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + if self.return_clip_output: + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + else: + self.fc1 = nn.Identity() + self.fc_audioset = nn.Identity() + + self.init_weight() + + def init_weight(self) -> None: + init_bn(self.bn0) + if self.return_clip_output: + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def load_state_dict( + self, + 
state_dict: Mapping[str, Any], + strict: bool = True, + ) -> _IncompatibleKeys: + exclude_keys = [] + if not self.waveform_input: + exclude_keys += [ + "spectrogram_extractor.stft.conv_real.weight", + "spectrogram_extractor.stft.conv_imag.weight", + "logmel_extractor.melW", + ] + if not self.return_clip_output: + exclude_keys += [ + "fc1.weight", + "fc1.bias", + "fc_audioset.weight", + "fc_audioset.bias", + ] + + if len(exclude_keys) > 0: + pylog.warning(f"Auto-exclude keys {tuple(exclude_keys)}.") + + state_dict = dict(state_dict) + for key in exclude_keys: + state_dict.pop(key, None) + + return super().load_state_dict(state_dict, strict) # type: ignore + + def forward( + self, + input_: Tensor, + input_shapes: Tensor, + mixup_lambda: Optional[Tensor] = None, + ) -> dict[str, Tensor]: + """ + Input: (batch_size, data_length) if waveform_input=True else (batch_size, 1, time_steps, mel_bins) + """ + + if self.waveform_input: + input_time_dim = -1 + x = self.spectrogram_extractor( + input_ + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + else: + input_time_dim = -2 + x = input_ + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training and self.use_spec_aug: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + output_dict = {} + + if self.return_frame_output: + frame_embs = x + + # input_: (bsize, n_channels=1, time_steps=1001, mel_bins=64) + # x: (bsize, emb_size=2048, time_steps=31) + + input_lens = input_shapes[:, input_time_dim] + reduction_factor = input_.shape[input_time_dim] // frame_embs.shape[-1] + frame_embs_lens = input_lens.div(reduction_factor, rounding_mode="trunc") + + output_dict |= { + # (bsize, embed=2048, n_frames=31) + "frame_embs": frame_embs, + # (bsize,) + "frame_embs_lens": frame_embs_lens, + } + + if self.return_clip_output: + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict |= {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict diff --git a/src/conette/nn/encoders/cnn14_decisionlevel_att.py b/src/conette/nn/encoders/cnn14_decisionlevel_att.py new file mode 100644 index 000000000..e44507341 --- /dev/null +++ b/src/conette/nn/encoders/cnn14_decisionlevel_att.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import 
SpecAugmentation + +from conette.nn.pann_utils.pytorch_utils import ( + do_mixup, + interpolate, + pad_framewise_output, +) +from conette.nn.pann_utils.models import AttBlock, ConvBlock, init_bn, init_layer +from conette.nn.pann_utils.ckpt import pann_load_state_dict +from conette.transforms.audio.cutoutspec import CutOutSpec +from conette.transforms.mixup import Mixup, sample_lambda + + +class Cnn14_DecisionLevelAtt(nn.Module): + def __init__( + self, + sr: int = 32000, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + classes_num: int = 527, + use_cutout: bool = False, + use_pann_mixup: bool = False, + use_spec_augment: bool = False, + return_clip_outputs: bool = False, + pretrained: bool = True, + freeze_weight: str = "none", + waveform_input: bool = True, + ) -> None: + if freeze_weight != "none" and pretrained: + raise RuntimeError( + f"Cannot freeze weights without using pre-trained weights. (found {freeze_weight=} && {pretrained=})" + ) + + super().__init__() + self.use_cutout = use_cutout + self.use_pann_mixup = use_pann_mixup + self.use_specaug = use_spec_augment + self.return_clip_outputs = return_clip_outputs + self.waveform_input = waveform_input + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + + if waveform_input: + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + else: + self.spectrogram_extractor = nn.Identity() + self.logmel_extractor = nn.Identity() + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.cutout = CutOutSpec(fill_value=float(fmin)) + self.mixup = Mixup(alpha=0.4, asymmetric=True) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + if self.return_clip_outputs: + self.att_block = AttBlock(2048, classes_num, activation="sigmoid") + else: + self.att_block = None + + # Initialize weights + self.init_weight() + + if pretrained: + self.load_pretrained_weights() + + self.freeze_weights(freeze_weight) + + def load_pretrained_weights(self, strict: bool = False) -> None: + device = self.fc1.weight.device + state_dict = pann_load_state_dict("Cnn14_DecisionLevelAtt", device, True) + self.load_state_dict(state_dict, strict=strict) + + def init_weight(self) -> None: + init_bn(self.bn0) + init_layer(self.fc1) + + def freeze_weights(self, freeze_mode: str) -> None: + if freeze_mode == "none": + pass + else: + if freeze_mode == "all": + excluded_lst = [] + elif freeze_mode == "first1": + excluded_lst = ["conv_block1"] + elif freeze_mode == "first2": + excluded_lst = 
["conv_block1", "conv_block2"] + elif freeze_mode == "first3": + excluded_lst = ["conv_block1", "conv_block2", "conv_block3"] + elif freeze_mode == "last1": + excluded_lst = ["fc1"] + elif freeze_mode == "last2": + excluded_lst = ["fc1", "conv_block6"] + elif freeze_mode == "last3": + excluded_lst = ["fc1", "conv_block6", "conv_block5"] + else: + raise RuntimeError(f'Unknown freeze encoder mode "{freeze_mode=}".') + + for name, param in self.named_parameters(): + if all(not name.startswith(excluded) for excluded in excluded_lst): + param.requires_grad = False + + def forward( + self, + input_: Tensor, + input_shapes: Tensor, + mixup_params: Optional[dict[str, Tensor]] = None, + ) -> dict[str, Tensor]: + """ + :param input: (bsize, audio_len) + :param input_lens: (bsize, ...) + :param mixup_params: {'lambda1': float, 'lambda2': float, 'indexes': IntTensor of shape (bsize,)} or None + """ + + if self.waveform_input: + if len(input_.shape) != 2: + raise RuntimeError( + f'Model "{self.__class__.__name__}" expects raw audio batch tensor of shape (bsize, audio_len), but found shape {input_.shape}.' + ) + + x = self.spectrogram_extractor( + input_ + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + else: + x = input_ + + input_time_dim = -2 + frames_num = x.shape[2] + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training and self.use_specaug: + x = self.spec_augmenter(x) + + if self.training and self.use_cutout: + x = self.cutout(x) + + if self.training and self.use_pann_mixup: + mixup_lambda = sample_lambda(self.mixup.alpha, self.mixup.asymmetric) + indexes = torch.randperm(len(x)) + x = x * mixup_lambda + x[indexes] * (1.0 - mixup_lambda) + x = do_mixup(x, mixup_lambda) + + # Mixup on spectrogram + if self.training and mixup_params is not None: + lambda1 = mixup_params["lambda1"] + lambda2 = mixup_params["lambda2"] + indexes = mixup_params["indexes"] + x = lambda1 * x + lambda2 * x[indexes] + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = x.transpose(1, 2) + x = F.relu_(self.fc1(x)) + x = x.transpose(1, 2) + x = F.dropout(x, p=0.5, training=self.training) + + frame_embs = x + + input_lens = input_shapes[:, input_time_dim] + reduction_factor = input_.shape[input_time_dim] // frame_embs.shape[-1] + frame_embs_lens = input_lens.div(reduction_factor, rounding_mode="trunc") + + output_dict = { + # (bsize, embed=2048, n_frames) + "frame_embs": frame_embs, + # (bsize,) + "frame_embs_lens": frame_embs_lens, + } + + if self.return_clip_outputs: + assert self.att_block is not None + (clip_logits, _, segmentwise_output) = self.att_block(x) + segmentwise_output = 
segmentwise_output.transpose(1, 2)
+
+            # Get framewise output
+            frame_logits = interpolate(segmentwise_output, self.interpolate_ratio)
+            frame_logits = pad_framewise_output(frame_logits, frames_num)
+
+            output_dict |= {
+                # (bsize, n_frames, n_classes)
+                "frame_logits": frame_logits,
+                # (bsize, n_classes)
+                "clip_logits": clip_logits,
+            }
+
+        return output_dict
diff --git a/src/conette/nn/encoders/convnext.py b/src/conette/nn/encoders/convnext.py
index 346fea19f..6b5066742 100644
--- a/src/conette/nn/encoders/convnext.py
+++ b/src/conette/nn/encoders/convnext.py
@@ -14,7 +14,7 @@
 from conette.nn.modules.drop import DropPath
 from conette.nn.modules.norm import LayerNorm
 from conette.transforms.mixup import pann_mixup
-from conette.transforms.speed_perturb import SpeedPerturbation
+from conette.transforms.audio.speed_perturb import SpeedPerturbation
 
 
 class ConvNeXtBlock(nn.Module):
diff --git a/src/conette/nn/functional/misc.py b/src/conette/nn/functional/misc.py
new file mode 100644
index 000000000..638d69139
--- /dev/null
+++ b/src/conette/nn/functional/misc.py
@@ -0,0 +1,481 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import copy
+
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
+
+import torch
+
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from conette.nn.functional.crop import crop_dim
+from conette.nn.functional.mask import tensor_to_pad_mask
+from conette.nn.functional.pad import pad_dim
+
+
+T = TypeVar("T")
+
+
+def sharpen(x: Tensor, temperature: float, dim: int) -> Tensor:
+    x = x ** (1.0 / temperature)
+    x = x / x.norm(p=1, dim=dim, keepdim=True)  # type: ignore
+    return x
+
+
+def sort_batch_by_lengths(
+    batch: dict[str, Tensor],
+    key: str = "audio_lens",
+) -> dict[str, Tensor]:
+    # Sort every entry of the batch by the lengths stored under `key`.
+    indexes = torch.argsort(batch[key])
+
+    keys = list(batch.keys())
+    result = {}
+    for k in keys:
+        value = batch[k]
+        del batch[k]
+        if isinstance(value, Tensor):
+            result[k] = value[indexes]
+        elif isinstance(value, Sequence):
+            result[k] = [value[i] for i in indexes]
+        else:
+            raise ValueError(f"Unsupported value type. 
({value=})") + return result + + +def count_params(model: nn.Module, only_trainable: bool = False) -> int: + return sum( + param.numel() + for param in model.parameters() + if not only_trainable or param.requires_grad + ) + + +def module_eq(m1: nn.Module, m2: nn.Module) -> bool: + n_params1 = sum(1 for _ in m1.parameters()) + n_params2 = sum(1 for _ in m2.parameters()) + return n_params1 == n_params2 and all( + p1.shape == p2.shape and p1.eq(p2).all() + for p1, p2 in zip(m1.parameters(), m2.parameters()) + ) + + +def module_mean(modules: Iterable[nn.Module], with_buffers: bool = True) -> nn.Module: + modules = list(modules) + assert len(modules) > 0 + + output = copy.deepcopy(modules[0]) + + all_params = [output.parameters()] + [module.parameters() for module in modules] + for params in zip(*all_params): + params[0][:] = torch.stack(params[1:]).mean(dim=0) + + if with_buffers: + all_buffers = [output.buffers()] + [module.buffers() for module in modules] + for buffers in zip(*all_buffers): + if buffers[0].is_floating_point(): + buffers[0][:] = torch.stack(buffers[1:]).mean(dim=0) + + return output + + +def tensor_eq(t1: Tensor, t2: Tensor) -> bool: + return t1.shape == t2.shape and bool(t1.eq(t2).all().item()) + + +def tensor_close( + t1: Tensor, + t2: Tensor, + rtol: float = 0.00001, + atol: float = 1e-8, + equal_nan: bool = False, +) -> bool: + return t1.shape == t2.shape and torch.allclose( + t1, t2, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + +def tensors_list_to_tensor( + tensors: list[Tensor], + pad_value: float, + batch_first: bool = True, +) -> Tensor: + """Pad a list of tensors to a tensor. + + :param tensors: List of N tensors with the same number of dims and with a different size at the first dim. + :param pad_value: The value used for fill the tensors. 
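+    :param batch_first: If True, the batch dimension comes first in the output, i.e. (N, max_len, *).
+        Forwarded to :func:`~torch.nn.utils.rnn.pad_sequence`. defaults to True.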
+ :returns: (N, *) + """ + return pad_sequence(tensors, batch_first=batch_first, padding_value=pad_value) + + +def reduce_mask( + mask: Tensor, + target_size: int, + dim: int = -1, + round_fn: Union[str, Callable[[Tensor], Tensor]] = "round", +) -> Tensor: + if isinstance(round_fn, str): + if round_fn == "round": + round_fn = torch.round + elif round_fn == "floor": + round_fn = torch.floor + elif round_fn == "ceil": + round_fn = torch.ceil + else: + raise ValueError(f"Invalid argument {round_fn=}.") + + orig_size = mask.shape[dim] + factor = target_size / orig_size + indexes = round_fn(torch.arange(0, orig_size, step=factor)) + slices: list[Any] = [slice(None) for _ in range(mask.ndim)] + slices[dim] = indexes + return mask[slices] + + +def stack_tensors_rec( + sequence: Union[Tensor, int, float, tuple, list], + relaxed: bool = False, + dtype: Union[None, torch.dtype] = None, + device: Union[str, torch.device, None] = "auto", +) -> Union[Tensor, list]: + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + def stack_tensors_rec_impl( + seq: Union[Tensor, int, float, tuple, list] + ) -> Union[Tensor, list]: + if isinstance(seq, Tensor): + return seq.to(dtype=dtype, device=device) + elif isinstance(seq, (int, float)): + return torch.as_tensor(seq, dtype=dtype, device=device) # type: ignore + elif isinstance(seq, (list, tuple)): + if all(isinstance(elt, (int, float)) for elt in seq): # type: ignore + return torch.as_tensor(seq, dtype=dtype, device=device) # type: ignore + + seq = [stack_tensors_rec_impl(elt) for elt in seq] # type: ignore + if all(isinstance(elt, Tensor) for elt in seq): + shapes = [elt.shape for elt in seq] # type: ignore + if len(seq) == 0 or all(shape == shapes[0] for shape in shapes): + return torch.stack(seq) # type: ignore + elif relaxed: + return seq + else: + raise ValueError( + f"Cannot stack tensors of different shapes. (found {shapes=})" + ) + elif relaxed: + return seq + else: + raise ValueError("Cannot stack tensors of different shape or types.") + else: + raise TypeError( + f"Invalid type {type(seq)}. (expected Tensor, int, float, list or tuple)" + ) + + sequence = stack_tensors_rec_impl(sequence) + return sequence + + +def batch_conv2d_naive(x: Tensor, weight: Tensor) -> Tensor: + """ + Conv2d with a batch of distincts weights. (slow version using Conv2d multiple times) + + :param x: (bsize, in_channels, x_width, x_height) + :param weight: (bsize, out_channels, in_channels, weight_width, weight_height) + :returns: (bsize, out_channels, x_width, x_height) + """ + if ( + x.ndim != 4 + or weight.ndim != 5 + or x.shape[0] != weight.shape[0] + or x.shape[1] != weight.shape[2] + ): + raise ValueError( + f"Invalid arguments for batch_conv2d_naive. ({x.shape=}; {weight.shape=})" + ) + + x = torch.stack( + [ + F.conv2d(x_i.unsqueeze(dim=0), weight=w_i, bias=None, padding="same") + for x_i, w_i in zip(x, weight) + ] + ) + x = x.squeeze(dim=1) + return x.contiguous() + + +def batch_conv2d(x: Tensor, weight: Tensor) -> Tensor: + """ + Conv2d with a batch of distincts weights. 
(faster version using only 1 Conv2d with groups) + + :param x: (bsize, in_channels, x_width, x_height) + :param weight: (bsize, out_channels, in_channels, weight_width, weight_height) + :returns: (bsize, out_channels, x_width, x_height) + """ + if ( + x.ndim != 4 + or weight.ndim != 5 + or x.shape[0] != weight.shape[0] + or x.shape[1] != weight.shape[2] + ): + raise ValueError( + f"Invalid arguments for batch_conv2d. ({x.shape=}; {weight.shape=})" + ) + + x_width, x_height = x.shape[2:] + bsize, out_channels, in_channels, weight_width, weight_height = weight.shape + x = x.view(1, bsize * in_channels, x_width, x_height).contiguous() + weight = weight.view( + bsize * out_channels, in_channels, weight_width, weight_height + ).contiguous() + x = F.conv2d(x, weight=weight, bias=None, padding="same", groups=bsize) + x = x.view(bsize, out_channels, x_width, x_height) + return x.contiguous() + + +def move_to_rec( + x: Any, + *args, + predicate: Optional[Callable[[Any], bool]] = None, + **kwargs, +) -> Any: + if isinstance(x, (Tensor, nn.Module)): + if predicate is None or predicate(x): + return x.to(*args, **kwargs) + else: + return x + elif isinstance(x, (str, float, int)): + return x + elif isinstance(x, Mapping): + return { + k: move_to_rec(v, predicate=predicate, *args, **kwargs) + for k, v in x.items() + } + elif isinstance(x, Iterable): + generator = (move_to_rec(xi, predicate=predicate, *args, **kwargs) for xi in x) + if isinstance(x, Generator): + return generator + elif isinstance(x, tuple): + return tuple(generator) + elif isinstance(x, list): + return list(generator) + else: + return list(generator) + else: + return x + + +def pad_crop_dim( + x: Tensor, + target_length: int, + align: str = "left", + fill_value: float = 0.0, + dim: int = -1, + mode: str = "constant", +) -> Tensor: + if x.shape[dim] == target_length: + return x + elif x.shape[dim] > target_length: + return crop_dim(x, target_length, align, dim) + else: + return pad_dim(x, target_length, align, fill_value, dim, mode) + + +def pad_and_cat( + tensors: Iterable[Tensor], + dim_pad: int, + dim_cat: int, + fill_value: float, + align: str = "left", +) -> Tensor: + target_len = max(tensor.shape[dim_pad] for tensor in tensors) + tensors = [ + pad_dim(tensor, target_len, align, fill_value, dim_pad) for tensor in tensors + ] + tensors = torch.cat(tensors, dim=dim_cat) + return tensors + + +def pad_and_stack( + tensors: Iterable[Tensor], + dim_pad: int, + fill_value: float, + align: str = "left", +) -> Tensor: + target_len = max(tensor.shape[dim_pad] for tensor in tensors) + tensors = [ + pad_dim(tensor, target_len, align, fill_value, dim_pad) for tensor in tensors + ] + tensors = torch.stack(tensors) + return tensors + + +def can_be_stacked(tensors: list[Tensor]) -> bool: + """Returns True if the list contains tensor that can be stacked with :func:`~torch.stack`.""" + return len(tensors) == 0 or all( + tensor.shape == tensors[0].shape for tensor in tensors + ) + + +def can_be_stacked_v2(tensors: list[Tensor]) -> bool: + """Returns True if the list contains tensor that can be stacked with :func:`~torch.stack`.""" + if len(tensors) == 0: + return False + else: + shapes = torch.as_tensor([tensor.shape for tensor in tensors]) + return shapes.eq(shapes[0]).all().item() # type: ignore + + +def can_be_padded(tensors: list[Tensor]) -> bool: + """Returns True if the list contains tensor that can be stacked with :func:`~torch.nn.utils.rnn.pad_sequence`.""" + return len(tensors) == 0 or all( + tensor.shape[1:] == tensors[0].shape[1:] for tensor 
in tensors + ) + + +def check_pred( + pred: Tensor, + pad_id: int = 0, + bos_id: int = 1, + eos_id: int = 2, + unk_id: int = 3, +) -> tuple[bool, bool, bool, bool]: + """Check if a prediction tensor is valid. + + :param pred: (bsize, pred_len) + :returns: (sos_at_start, eos_at_end, no_unk, pad_at_end) + """ + assert pred.ndim == 2 + dim = 1 + sos_at_start = pred[:, 0].eq(bos_id).all().item() + contains_eos = (pred == eos_id).any(dim=dim) + eos_at_end = contains_eos.all().item() + no_unk = pred.ne(unk_id).all().item() + + indexes_eos = (pred == eos_id).int().argmax(dim=dim) + lengths = torch.where(contains_eos, indexes_eos, pred.shape[dim]) + pad_at_end = True + for pred_i, len_i in zip(pred, lengths): + pad_at_end = pad_at_end and pred_i[len_i + 1 :].eq(pad_id).all().item() + + return sos_at_start, eos_at_end, no_unk, pad_at_end # type: ignore + + +def find( + tensor: Tensor, + value: Any, + default: Union[None, Tensor, int, float] = None, + dim: int = -1, +) -> Tensor: + """Return the index of the first occurrence of value in a tensor.""" + assert tensor.ndim > 0 + mask = tensor.eq(value) + contains_eos = mask.any(dim=dim) + indexes_eos = mask.int().argmax(dim=dim) + + if default is None: + if not contains_eos.all(): + raise RuntimeError(f"Cannot find {value=} in tensor.") + return indexes_eos + else: + output = torch.where(contains_eos, indexes_eos, default) + return output + + +def pad_after_eos(pred: Tensor, eos_id: int, pad_id: int) -> Tensor: + """ + :param pred: (bsize, pred_size) + :returns: (bsize, pred_size) + """ + pad_mask = tensor_to_pad_mask( + pred, + end_value=eos_id, + include_end=False, + ) + pred[pad_mask] = pad_id + return pred + + +def prepend_value(pred: Tensor, value: Union[int, float, bool], dim: int = 1) -> Tensor: + """ + :param pred: (bsize, pred_size) + :returns: (bsize, pred_size+1) + """ + shape = list(pred.shape) + shape[dim] = 1 + values = torch.full(shape, value, dtype=pred.dtype, device=pred.device) + pred = torch.cat((values, pred), dim=dim) + return pred + + +def cat_padded_batch( + x1: Tensor, + x1_lens: Tensor, + x2: Tensor, + x2_lens: Tensor, + seq_dim: int, + batch_dim: int = 0, +) -> tuple[Tensor, Tensor]: + assert x1.ndim == x2.ndim + assert x1_lens.ndim == x2_lens.ndim == 1 + assert ( + x1.shape[batch_dim] + == x2.shape[batch_dim] + == x1_lens.shape[0] + == x2_lens.shape[0] + ) + + x12_lens = x1_lens + x2_lens + sum_size_12 = x1.shape[seq_dim] + x2.shape[seq_dim] + + x12 = pad_dim(x1, sum_size_12, dim=seq_dim) + kwd: dict[str, Any] = dict(device=x1.device, dtype=torch.long) + indexes = torch.arange(x2_lens.max().item(), **kwd) + + unsq_x1_lens = x1_lens + ndim = x1.ndim + for i in range(ndim): + if i != (seq_dim % ndim): + indexes = indexes.unsqueeze(dim=i) + if i != (batch_dim % ndim): + unsq_x1_lens = unsq_x1_lens.unsqueeze(dim=i) + + expand_size = list(x2.shape) + expand_size[seq_dim] = -1 + indexes = indexes.expand(*expand_size) + indexes = indexes + unsq_x1_lens + x12.scatter_(seq_dim, indexes, x2) + + max_size_12 = int(x12_lens.max().item()) + if max_size_12 < sum_size_12: + slices = [slice(None) for _ in range(ndim)] + slices[seq_dim] = slice(max_size_12) + x12 = x12[slices] + x12 = x12.contiguous() + return x12, x12_lens + + +def detect_time_dim(shapes: Tensor, default_if_all_eq: int = -1) -> int: + assert shapes.ndim == 2 + dim = default_if_all_eq + for i in range(shapes.shape[1]): + lens_i = shapes[:, i] + if lens_i.ne(lens_i[0]).any(): + dim = i + break + return dim diff --git a/src/conette/nn/functional/pad.py 
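To make the scatter-based concatenation in `cat_padded_batch` above concrete, here is a tiny worked example (function assumed in scope): each row keeps only its valid tokens from `x1` followed by those from `x2`, and the result is re-truncated to the longest combined length.

```py
import torch

x1 = torch.as_tensor([[1, 2, 0], [3, 0, 0]])  # padded batch, lengths (2, 1)
x1_lens = torch.as_tensor([2, 1])
x2 = torch.as_tensor([[7, 8], [9, 0]])        # padded batch, lengths (2, 1)
x2_lens = torch.as_tensor([2, 1])

x12, x12_lens = cat_padded_batch(x1, x1_lens, x2, x2_lens, seq_dim=1)
# x12      -> tensor([[1, 2, 7, 8], [3, 9, 0, 0]])
# x12_lens -> tensor([4, 2])
```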
b/src/conette/nn/functional/pad.py index b032421eb..4f5503c06 100644 --- a/src/conette/nn/functional/pad.py +++ b/src/conette/nn/functional/pad.py @@ -1,16 +1,147 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Iterable +from typing import Any, Iterable, Sized, Union import torch from torch import Tensor from torch.nn import functional as F +from torch.nn.utils.rnn import pad_sequence + PAD_ALIGNS = ("left", "right", "center", "random") +def pad_sequence_rec( + sequence: Union[Tensor, int, float, tuple, list], + pad_value: float, + dtype: Union[None, torch.dtype] = None, + device: Union[str, torch.device, None] = None, +) -> Tensor: + """Recursive version of torch.nn.utils.rnn.pad_sequence, with padding of Tensors. + + :param sequence: The sequence to pad. Must be convertable to tensor by having the correct number of dims in all sublists. + :param pad_value: The pad value used. + :param dtype: The dtype of the output Tensor. defaults to None. + :param device: The device of the output Tensor. defaults to None. + :returns: The sequence as a padded Tensor. + + Example 1 + ---------- + >>> sequence = [[1, 2], [3], [], [4, 5]] + >>> output = pad_sequence_rec(sequence, 0) + tensor([[1, 2], [3, 0], [0, 0], [4, 5]]) + + Example 2 + ---------- + >>> invalid_sequence = [[1, 2, 3], 3] + >>> output = pad_sequence_rec(invalid_sequence, 0) + ValueError : Cannot pad sequence of tensors of differents number of dims. + + """ + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if isinstance(sequence, Tensor): + return sequence.to(dtype=dtype, device=device) + + if isinstance(sequence, (int, float)) or ( + isinstance(sequence, Sized) and len(sequence) == 0 + ): + return torch.as_tensor(sequence, dtype=dtype, device=device) # type: ignore + + elif isinstance(sequence, (list, tuple)): + if all(isinstance(elt, (int, float)) for elt in sequence): + return torch.as_tensor(sequence, dtype=dtype, device=device) # type: ignore + + sequence = [pad_sequence_rec(elt, pad_value, dtype, device) for elt in sequence] + # sequence is now a list[Tensor] + shapes = [elt.shape for elt in sequence] + + # If all tensors have the same shape + if all(shape == shapes[0] for shape in shapes): + return torch.stack(sequence, dim=0) + + # If all tensors have the same number of dims + elif all(elt.ndim == sequence[0].ndim for elt in sequence): + if all(shape[1:] == shapes[0][1:] for shape in shapes): + return pad_sequence(sequence, True, pad_value) + else: + max_lens = [ + max(shape[i] for shape in shapes) for i in range(sequence[0].ndim) + ] + paddings = [ + [ + (max_lens[i] - elt.shape[i]) * j + for i in range(-1, -sequence[0].ndim, -1) + for j in range(2) + ] + for elt in sequence + ] + sequence = [ + F.pad(elt, padding, value=pad_value) + for elt, padding in zip(sequence, paddings) + ] + return pad_sequence(sequence, True, pad_value) + + else: + raise ValueError( + f"Cannot pad sequence of tensors of differents number of dims. ({sequence=}, {shapes=})" + ) + + else: + raise TypeError( + f"Invalid type {type(sequence)}. 
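Beyond the flat example in its docstring, `pad_sequence_rec` also handles deeper ragged nesting: trailing dims are padded to their max before the final `pad_sequence` call. A small sketch, assuming the function above is in scope:

```py
nested = [[[1, 2], [3]], [[4]]]
out = pad_sequence_rec(nested, 0)
# out -> tensor([[[1, 2], [3, 0]],
#                [[4, 0], [0, 0]]])   # shape (2, 2, 2)
```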
(expected Tensor, int, float, list or tuple)" + ) + + +def pad_sequence_1d(tensors: list[Tensor], pad_value: float) -> Tensor: + if not all(tensor.ndim == 1 for tensor in tensors): + raise ValueError("Invalid argument tensors for pad_sequence_1d.") + + max_len = max(tensor.shape[0] for tensor in tensors) + output = torch.empty( + (len(tensors), max_len), device=tensors[0].device, dtype=tensors[0].dtype + ) + for i, tensor in enumerate(tensors): + output[i, : tensor.shape[0]] = tensor + output[i, tensor.shape[0] :] = pad_value + return output + + +def pad_sequence_nd(tensors: list[Tensor], pad_value: float) -> Tensor: + if not all(tensor.ndim >= 1 for tensor in tensors): + raise ValueError("Invalid argument tensors for pad_sequence_nd. (expected tensors with at least 1 dim)") + if not all(tensor.shape[1:] == tensors[0].shape[1:] for tensor in tensors[1:]): + raise ValueError("Invalid argument tensors for pad_sequence_nd. (expected tensors with matching trailing shapes)") + + max_len = max(tensor.shape[0] for tensor in tensors) + output = torch.empty( + (len(tensors), max_len) + tuple(tensors[0].shape[1:]), + device=tensors[0].device, + dtype=tensors[0].dtype, + ) + for i, tensor in enumerate(tensors): + output[i, : tensor.shape[0]] = tensor + output[i, tensor.shape[0] :] = pad_value + return output + + +def pad_last_dim(tensor: Tensor, target_length: int, pad_value: float) -> Tensor: + """Left-align and pad the last dim of a tensor up to target_length (padding is appended on the right). + + :param tensor: Tensor of at least 1 dim. (..., T) + :param target_length: Target length of the last dim. If target_length <= T, the function has no effect. + :param pad_value: Fill value used to pad tensor. + :returns: A tensor of shape (..., target_length). + """ + pad_size = max(target_length - tensor.shape[-1], 0) + return F.pad(tensor, [0, pad_size], value=pad_value) + + def pad_dim( x: Tensor, target_length: int, @@ -49,6 +180,95 @@ def pad_dim( return x + +def pad_dims( + x: Tensor, + target_lengths: Union[int, Iterable[int]], + aligns: Union[str, Iterable[str]] = "left", + fill_value: float = 0.0, + dims: Iterable[int] = (-1,), + mode: str = "constant", +) -> Tensor: + """Generic function to pad multiple dimensions.""" + dims = list(dims) + if len(dims) == 0: + raise ValueError( + f"Invalid argument {dims=}. (cannot use an empty list of dimensions)" + ) + + if isinstance(target_lengths, int): + target_lengths = [target_lengths] * len(dims) + else: + target_lengths = list(target_lengths) + + if isinstance(aligns, str): + aligns = [aligns] * len(dims) + else: + aligns = list(aligns) + + if len(target_lengths) != len(dims): + raise ValueError( + f"Invalid number of target lengths ({len(target_lengths)}) with the number of dimensions ({len(dims)})." + ) + + if len(aligns) != len(dims): + raise ValueError( + f"Invalid number of aligns ({len(aligns)}) with the number of dimensions ({len(dims)})." + ) + + pad_seq = [0 for _ in range(len(x.shape) * 2)] + for target_length, dim, align in zip(target_lengths, dims, aligns): + missing = max(target_length - x.shape[dim], 0) + + if align == "left": + missing_left = 0 + missing_right = missing + elif align == "right": + missing_left = missing + missing_right = 0 + elif align == "center": + missing_left = missing // 2 + missing % 2 + missing_right = missing // 2 + elif align == "random": + missing_left = int(torch.randint(low=0, high=missing + 1, size=()).item()) + missing_right = missing - missing_left + else: + ALIGNS = ("left", "right", "center", "random") + raise ValueError(f"Invalid argument {align=}. 
(expected one of {ALIGNS})") + + # Note: pad_seq : [pad_left_dim_-1, pad_right_dim_-1, pad_left_dim_-2, pad_right_dim_-2, ...) + idx = len(x.shape) - (dim % len(x.shape)) - 1 + assert pad_seq[idx * 2] == 0 and pad_seq[idx * 2 + 1] == 0 + pad_seq[idx * 2] = missing_left + pad_seq[idx * 2 + 1] = missing_right + + x = F.pad(x, pad_seq, mode=mode, value=fill_value) + return x + + +def stack_tensors(tensors: Iterable[Tensor], pad_value: float) -> Tensor: + tensors = list(tensors) + if len(tensors) == 0: + raise ValueError(f"Invalid argument {tensors=}.") + + d0_sum = sum(tensor.shape[0] for tensor in tensors) + max_shapes = tuple( + max(tensor.shape[i] for tensor in tensors) for i in range(1, tensors[0].ndim) + ) + + factory_kws: dict[str, Any] = dict(dtype=tensors[0].dtype, device=tensors[0].device) + output = torch.full((d0_sum,) + max_shapes, pad_value, **factory_kws) + + d0_start = 0 + for tensor in tensors: + d0_end = d0_start + tensor.shape[0] + slices = (slice(d0_start, d0_end),) + tuple( + slice(shape_i) for shape_i in tensor.shape[1:] + ) + output[slices] = tensor + d0_start = d0_end + return output + + def pad_and_stack(x: Iterable[Tensor], dim: int = -1) -> Tensor: if isinstance(x, Tensor): return x diff --git a/src/conette/nn/modules/misc.py b/src/conette/nn/modules/misc.py new file mode 100644 index 000000000..116232038 --- /dev/null +++ b/src/conette/nn/modules/misc.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import inspect +import tqdm + +from dataclasses import asdict, astuple +from typing import Any, Callable, Iterable, Mapping, Optional + +import torch + +from torch import nn, Tensor + + +class AmplitudeToLog(nn.Module): + def __init__(self, eps: float = torch.finfo(torch.float).eps) -> None: + super().__init__() + self.eps = eps + + def forward(self, data: Tensor) -> Tensor: + return torch.log(data + self.eps) + + +class Lambda(nn.Module): + def __init__(self, fn: Callable, **default_kwargs) -> None: + """Wrap a callable function or object to a Module.""" + super().__init__() + self.fn = fn + self.default_kwargs = default_kwargs + + def forward(self, *args, **kwargs) -> Any: + kwargs = self.default_kwargs | kwargs + return self.fn(*args, **kwargs) + + def extra_repr(self) -> str: + if isinstance(self.fn, nn.Module): + return "" + elif inspect.isfunction(self.fn): + return self.fn.__name__ + elif inspect.ismethod(self.fn): + return self.fn.__qualname__ + else: + return self.fn.__class__.__name__ + + +class Reshape(nn.Module): + def __init__(self, shape: tuple[int, ...]) -> None: + super().__init__() + self.shape = shape + + def forward(self, x: Tensor) -> Tensor: + return torch.reshape(x, self.shape) + + +class Print(nn.Module): + def __init__( + self, + preprocess: Optional[Callable] = None, + prefix: str = "DEBUG - ", + ) -> None: + super().__init__() + self._preprocess = preprocess + self._prefix = prefix + + def forward(self, x: Any) -> Any: + x_out = x + if self._preprocess is not None: + x = self._preprocess(x) + print(f"{self._prefix}{x=}") + return x_out + + +class AsTensor(nn.Module): + def __init__(self, **kwargs) -> None: + super().__init__() + self.kwargs = kwargs + + def forward(self, inp: list, *args, **kwargs) -> Tensor: + kwargs = self.kwargs | kwargs + return torch.as_tensor(inp, *args, **kwargs) + + def extra_repr(self) -> str: + kwargs_str = ",".join(f"{k}={v}" for k, v in self.kwargs.items()) + return f"kwargs=dict({kwargs_str})" + + +class ParallelDict(nn.ModuleDict): + """Compute output of each submodule value when forward(.) 
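Before moving on to the module wrappers in misc.py, a short illustration of `pad_dims` and `stack_tensors` from the pad.py hunk above (both assumed in scope); with the default "left" align the original content stays in the top-left corner.

```py
import torch

x = torch.ones(2, 3)
y = pad_dims(x, target_lengths=(4, 5), dims=(0, 1), fill_value=0.0)
# y.shape -> (4, 5); y[:2, :3] is all ones, the rest is zeros

z = stack_tensors([torch.ones(2, 3), torch.full((1, 5), 2.0)], pad_value=-1.0)
# z.shape -> (3, 5): rows 0-1 come from the first tensor (last 2 cols = -1),
# row 2 comes from the second tensor
```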
is called.""" + + def __init__( + self, modules: Optional[dict[str, nn.Module]] = None, verbose: bool = False + ) -> None: + super().__init__(modules) + self._verbose = verbose + + def forward(self, *args, **kwargs) -> dict[str, Any]: + tqdm_obj = tqdm.tqdm( + self.items(), desc=f"{self.__class__.__name__}", disable=not self._verbose + ) + outs = {} + for name, module in tqdm_obj: + tqdm_obj.set_description( + f"{self.__class__.__name__}:{module.__class__.__name__}" + ) + outs[name] = module(*args, **kwargs) + return outs + + +class ParallelList(nn.ModuleList): + def __init__( + self, modules: Optional[Iterable[nn.Module]] = (), verbose: bool = False + ) -> None: + super().__init__(modules) + self._verbose = verbose + + def forward(self, *args, **kwargs) -> list[Any]: + tqdm_obj = tqdm.tqdm( + self, + disable=not self._verbose, + desc=f"{self.__class__.__name__}", + ) + outs = [] + for module in tqdm_obj: + tqdm_obj.set_description( + f"{self.__class__.__name__}:{module.__class__.__name__}" + ) + outs.append(module(*args, **kwargs)) + return outs + + +class SequentialArgs(nn.Sequential): + def forward(self, *args) -> Any: + x = args + for module in self: + if isinstance(x, tuple): + x = module(*x) + else: + x = module(x) + return x + + +class SequentialKwargs(nn.Sequential): + def forward(self, **kwargs) -> Any: + x = kwargs + for module in self: + if isinstance(x, dict): + x = module(**x) + else: + x = module(x) + return x + + +class Standardize(nn.Module): + def __init__(self, unbiased_std: bool = True) -> None: + super().__init__() + self.unbiased_std = unbiased_std + + def forward(self, x: Tensor) -> Tensor: + x = (x - x.mean()) / x.std(unbiased=self.unbiased_std) + return x + + +class AsDict(nn.Module): + def forward(self, x: Any) -> dict[str, Any]: + return asdict(x) + + +class AsTuple(nn.Module): + def forward(self, x: Any) -> tuple[Any, ...]: + return astuple(x) + + +class DictTransformModule(nn.ModuleDict): + """Wrap a dictionary of modules to apply to each value of a dictionary input at a corresponding key. + + Example 1 + ---------- + ```py + >>> mean_a = DictTransformModule({"a": nn.ReLU()}) + >>> input = {"a": torch.as_tensor([-1., 2, -3]), "b": "something", "c": torch.as_tensor([1, 2, 3])} + >>> mean_a(input) + ... 
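A minimal usage sketch for `Lambda` and `ParallelDict` above (both assumed in scope): every named branch receives the same input and the outputs come back keyed by branch name.

```py
import torch
from torch import nn

branches = ParallelDict({"relu": nn.ReLU(), "neg": Lambda(torch.neg)})
out = branches(torch.as_tensor([-1.0, 2.0]))
# out -> {"relu": tensor([0., 2.]), "neg": tensor([1., -2.])}
```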
{"a": tensor([0.0, 2.0, 0.0]), "b": "something", "c": tensor([1, 2, 3])} + ``` + """ + + def __init__( + self, modules: Optional[Mapping[str, Optional[nn.Module]]] = None, **kwargs + ) -> None: + if modules is None: + modules = {} + else: + modules = dict(modules) + modules = modules | kwargs + modules = {k: v for k, v in modules.items() if v is not None} + super().__init__(modules) + + def forward(self, dic: dict[str, Any]) -> dict[str, Any]: + for name, module in self.items(): + if name in dic: + dic[name] = module(dic[name]) + return dic + + +class IdMapping(nn.Module): + def __init__(self, mapper: Tensor) -> None: + super().__init__() + self.mapper = mapper + + @classmethod + def from_dic( + cls, dic: Mapping[int, int], dtype: torch.dtype = torch.long + ) -> "IdMapping": + max_src_id = max(dic.keys()) + mapper = torch.zeros((max_src_id,), dtype=dtype) + for k, v in dic.items(): + mapper[k] = v + return IdMapping(mapper) + + def forward(self, ids: Tensor) -> Tensor: + assert not ids.is_floating_point() + return self.mapper[ids] diff --git a/src/conette/nn/pann_utils/__init__.py b/src/conette/nn/pann_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/conette/nn/pann_utils/ckpt.py b/src/conette/nn/pann_utils/ckpt.py new file mode 100644 index 000000000..944813b2b --- /dev/null +++ b/src/conette/nn/pann_utils/ckpt.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os.path as osp + +from typing import Union + +import torch + +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +# Zenodo link : https://zenodo.org/record/3987831 +# Hash type : md5 +PANN_PRETRAINED_URLS = { + "Cnn10": { + "model": "Cnn10", + "url": "https://zenodo.org/record/3987831/files/Cnn10_mAP%3D0.380.pth?download=1", + "hash": "bfb1f1f9968938fa8ef4012b8471f5f6", + "fname": "Cnn10_mAP_0.380.pth", + }, + "Cnn14_DecisionLevelAtt": { + "model": "Cnn14_DecisionLevelAtt", + "url": "https://zenodo.org/record/3987831/files/Cnn14_DecisionLevelAtt_mAP%3D0.425.pth?download=1", + "hash": "c8281ca2b9967244b91d557aa941e8ca", + "fname": "Cnn14_DecisionLevelAtt_mAP_0.425.pth", + }, + "Cnn14": { + "model": "Cnn14", + "url": "https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1", + "hash": "541141fa2ee191a88f24a3219fff024e", + "fname": "Cnn14_mAP_0.431.pth", + }, + "Cnn6": { + "model": "Cnn6", + "url": "https://zenodo.org/record/3987831/files/Cnn6_mAP%3D0.343.pth?download=1", + "hash": "e25e26b84585b14c7754c91e48efc9be", + "fname": "Cnn6_mAP_0.343.pth", + }, + "ResNet22": { + "model": "ResNet22", + "url": "https://zenodo.org/record/3987831/files/ResNet22_mAP%3D0.430.pth?download=1", + "hash": "cf36d413096793c4e15dc752a3abd599", + "fname": "ResNet22_mAP_0.430.pth", + }, + "ResNet38": { + "model": "ResNet38", + "url": "https://zenodo.org/record/3987831/files/ResNet38_mAP%3D0.434.pth?download=1", + "hash": "bf12f36aaabac4e0855e22d3c3239c1b", + "fname": "ResNet38_mAP_0.434.pth", + }, + "ResNet54": { + "model": "ResNet54", + "url": "https://zenodo.org/record/3987831/files/ResNet54_mAP%3D0.429.pth?download=1", + "hash": "4f1f1406d37a29e2379916885e18c5f3", + "fname": "ResNet54_mAP_0.429.pth", + }, + "Wavegram_Cnn14": { + "model": "Wavegram_Cnn14", + "url": "https://zenodo.org/record/3987831/files/Wavegram_Cnn14_mAP%3D0.389.pth?download=1", + "hash": "1e3506ab640371e0b5a417b15fd66d21", + "fname": "Wavegram_Cnn14_mAP_0.389.pth", + }, + "Wavegram_Logmel_Cnn14": { + "model": "Wavegram_Logmel_Cnn14", + "url": 
"https://zenodo.org/record/3987831/files/Wavegram_Logmel_Cnn14_mAP%3D0.439.pth?download=1", + "hash": "17fa9ab65af3c0eb5ffbc5f65552c4e1", + "fname": "Wavegram_Logmel_Cnn14_mAP_0.439.pth", + }, +} + + +def pann_get_ckpt_dir_path() -> str: + """Return the path to the directory containing PANN checkpoints files.""" + return osp.join(torch.hub.get_dir(), "checkpoints") + + +def pann_get_ckpt_path(model_name: str) -> str: + """Return the path to the PANN checkpoint file.""" + if model_name not in PANN_PRETRAINED_URLS: + raise ValueError( + f"Invalid argument {model_name=}. (expected one of {tuple(PANN_PRETRAINED_URLS.keys())})" + ) + + fname = PANN_PRETRAINED_URLS[model_name]["fname"] + fpath = osp.join(pann_get_ckpt_dir_path(), fname) + return fpath + + +def pann_load_state_dict( + model_name_or_path: str, + device: Union[str, torch.device, None] = None, + offline: bool = False, + verbose: int = 0, +) -> dict[str, Tensor]: + """Load PANN state_dict weights. + + :param model_name_or_path: Model name (case sensitive) or path to PANN checkpoint file. + :param device: Device of checkpoint weights. defaults to None. + :param offline: If False, the checkpoint from a model name will be automatically downloaded. + defaults to False. + :param verbose: Verbose level. defaults to 0. + :returns: State dict of model weights. + """ + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if osp.isfile(model_name_or_path): + model_path = model_name_or_path + else: + try: + model_path = pann_get_ckpt_path(model_name_or_path) + except ValueError: + raise ValueError( + f"Invalid argument {model_name_or_path=}. (expected a path to a checkpoint file or a model name in {tuple(PANN_PRETRAINED_URLS.keys())})" + ) + + if not osp.isfile(model_path): + if offline: + raise FileNotFoundError( + f"Cannot find checkpoint model file in '{model_path}' with mode {offline=}." + ) + else: + pann_download_ckpt(model_name_or_path, verbose) + + del model_name_or_path + + data = torch.load(model_path, map_location=device) + state_dict = data["model"] + + if verbose >= 1: + test_map = data.get("test_mAP", "unknown") + pylog.info( + f"Loading encoder weights from '{model_path}'... (with test_mAP={test_map})" + ) + + return state_dict + + +def pann_download_ckpt(model_name: str, verbose: int = 0) -> None: + """Download PANN checkpoint file.""" + fpath = pann_get_ckpt_path(model_name) + url = PANN_PRETRAINED_URLS[model_name]["url"] + torch.hub.download_url_to_file(url, fpath, progress=verbose >= 1) diff --git a/src/conette/nn/pann_utils/hub.py b/src/conette/nn/pann_utils/hub.py new file mode 100644 index 000000000..43b5e6971 --- /dev/null +++ b/src/conette/nn/pann_utils/hub.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Optional, Union + +import torch + +from torch import nn + +from conette.nn.pann_utils import models +from conette.nn.pann_utils.ckpt import ( + PANN_PRETRAINED_URLS, + pann_load_state_dict, +) + + +def build_pann_model( + model_name: str, + pretrained: bool = True, + model_kwargs: Optional[dict[str, Any]] = None, + device: Union[str, torch.device, None] = "auto", + offline: bool = False, + verbose: int = 0, + strict_load: bool = True, +) -> nn.Module: + """Build pretrained PANN model from name. + + :param model_name: PANN model name. (case sensitive) + :param pretrained: If True, load pretrained weights. defaults to True. 
+ :param model_kwargs: Optional keywords arguments passed to PANN model initializer. defaults to None. + :param device: Output device of the model. defaults to "auto". + :param offline: If True, disable automatic checkpoint downloading. defaults to False. + :param verbose: Verbose level during model build. defaults to 0. + :param strict_load: If True, check if checkpoint entirely corresponds to the initialized model. defaults to True. + :returns: The PANN model built as nn.Module. + """ + if model_name not in PANN_PRETRAINED_URLS: + raise ValueError( + f"Invalid argument {model_name=}. (expected one of {tuple(PANN_PRETRAINED_URLS.keys())})" + ) + + if model_kwargs is None: + model_kwargs = {} + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + classpath = f"{models.__name__}.{model_name}" + classtype = eval(classpath) + model: nn.Module = classtype(**model_kwargs) + + if pretrained: + state_dict = pann_load_state_dict( + model_name_or_path=model_name, + offline=offline, + verbose=verbose, + ) + model.load_state_dict(state_dict, strict=strict_load) + + model = model.to(device=device) + return model diff --git a/src/conette/nn/pann_utils/models.py b/src/conette/nn/pann_utils/models.py new file mode 100644 index 000000000..e2d4f6302 --- /dev/null +++ b/src/conette/nn/pann_utils/models.py @@ -0,0 +1,3956 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from conette.nn.pann_utils.pytorch_utils import ( + do_mixup, + interpolate, + pad_framewise_output, +) + + +def init_layer(layer) -> None: + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn) -> None: + """Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels) -> None: + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input: Tensor, pool_size=(2, 2), pool_type="avg"): + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 
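And the corresponding high-level entry point from hub.py. The constructor kwargs below are the usual PANN AudioSet settings and are an assumption here; with `strict_load=True` they must match the downloaded checkpoint.

```py
from conette.nn.pann_utils.hub import build_pann_model

model = build_pann_model(
    "Cnn14",
    pretrained=True,
    model_kwargs=dict(
        sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=527,
    ),
    device="cpu",
)
```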
2), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation="linear", temperature=1.0): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + self.cla = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == "linear": + return x + elif self.activation == "sigmoid": + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on 
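`AttBlock` above implements attention pooling over the time axis; a shape-level sketch (class assumed in scope):

```py
import torch

att = AttBlock(n_in=2048, n_out=527, activation="sigmoid")
x = torch.randn(4, 2048, 31)            # (batch, features, time)
clipwise, norm_att, segmentwise = att(x)
# clipwise: (4, 527); norm_att and segmentwise: (4, 527, 31)
```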
spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_no_specaug(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_no_specaug, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, 
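The classifier variants shown in this file all return the same output dict of clipwise scores and a clip embedding. For `Cnn14` (constructor kwargs as in the earlier sketch, an assumption):

```py
import torch

model = Cnn14(sample_rate=32000, window_size=1024, hop_size=320,
              mel_bins=64, fmin=50, fmax=14000, classes_num=527)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 32000))  # two 1-second clips at 32 kHz
# out["clipwise_output"]: (2, 527) sigmoid scores
# out["embedding"]:       (2, 2048) clip embedding taken after fc1
```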
training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_no_dropout(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_no_dropout, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn6(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = 
"reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = 
ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +def _resnet_conv3x3(in_planes, out_planes): + # 3x3 convolution with padding + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + dilation=1, + ) + + +def _resnet_conv1x1(in_planes, out_planes): + # 1x1 convolution + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) + + +class _ResnetBasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(_ResnetBasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + "_ResnetBasicBlock only supports groups=1 and base_width=64" + ) + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + + self.stride = stride + + self.conv1 = _resnet_conv3x3(inplanes, planes) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = _resnet_conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + self.init_weights() + + def init_weights(self): + init_layer(self.conv1) + init_bn(self.bn1) + init_layer(self.conv2) + init_bn(self.bn2) + nn.init.constant_(self.bn2.weight, 0) + + def forward(self, x): + identity = x + + if self.stride == 2: + out = F.avg_pool2d(x, kernel_size=(2, 2)) + else: + out = x + + out = self.conv1(out) + out = self.bn1(out) + out = self.relu(out) + out = F.dropout(out, p=0.1, training=self.training) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(identity) + + out += identity + out 
= self.relu(out) + + return out + + +class _ResnetBottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(_ResnetBottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + self.stride = stride + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = _resnet_conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = _resnet_conv3x3(width, width) + self.bn2 = norm_layer(width) + self.conv3 = _resnet_conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + self.init_weights() + + def init_weights(self): + init_layer(self.conv1) + init_bn(self.bn1) + init_layer(self.conv2) + init_bn(self.bn2) + init_layer(self.conv3) + init_bn(self.bn3) + nn.init.constant_(self.bn3.weight, 0) + + def forward(self, x): + identity = x + + if self.stride == 2: + x = F.avg_pool2d(x, kernel_size=(2, 2)) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = F.dropout(out, p=0.1, training=self.training) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(identity) + + out += identity + out = self.relu(out) + + return out + + +class _ResNet(nn.Module): + def __init__( + self, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(_ResNet, self).__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + self.layer1 = self._make_layer(block, 64, layers[0], stride=1) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + if stride == 1: + downsample = nn.Sequential( + _resnet_conv1x1(self.inplanes, planes * block.expansion), + norm_layer(planes * block.expansion), + ) + init_layer(downsample[0]) + init_bn(downsample[1]) + elif stride == 2: + downsample = nn.Sequential( + nn.AvgPool2d(kernel_size=2), + _resnet_conv1x1(self.inplanes, planes * block.expansion), + norm_layer(planes * block.expansion), + ) + init_layer(downsample[1]) + init_bn(downsample[2]) + + layers = [] + layers.append( + block( + 
self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + +class ResNet22(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(ResNet22, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) + + self.resnet = _ResNet( + block=_ResnetBasicBlock, layers=[2, 2, 2, 2], zero_init_residual=True + ) + + self.conv_block_after1 = ConvBlock(in_channels=512, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.resnet(x) + x = F.avg_pool2d(x, kernel_size=(2, 2)) + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.conv_block_after1(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class ResNet38(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(ResNet38, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + 
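The `_ResNet` trunk above (wrapped by ResNet22/38/54) only stacks the four residual stages; it expects a 64-channel feature map and halves the spatial size three times. A shape-check sketch:

```py
import torch

trunk = _ResNet(block=_ResnetBasicBlock, layers=[2, 2, 2, 2], zero_init_residual=True)
y = trunk(torch.randn(1, 64, 64, 32))
# y.shape -> (1, 512, 8, 4)
```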
self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) + + self.resnet = _ResNet( + block=_ResnetBasicBlock, layers=[3, 4, 6, 3], zero_init_residual=True + ) + + self.conv_block_after1 = ConvBlock(in_channels=512, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.resnet(x) + x = F.avg_pool2d(x, kernel_size=(2, 2)) + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.conv_block_after1(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class ResNet54(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(ResNet54, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64) + + self.resnet = _ResNet( + block=_ResnetBottleneck, layers=[3, 4, 6, 3], 
zero_init_residual=True + ) + + self.conv_block_after1 = ConvBlock(in_channels=2048, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.resnet(x) + x = F.avg_pool2d(x, kernel_size=(2, 2)) + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = self.conv_block_after1(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training, inplace=True) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_emb512(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_emb512, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if 
self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_emb128(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_emb128, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 128, bias=True) + self.fc_audioset = nn.Linear(128, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 
2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_emb32(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_emb32, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 32, bias=True) + self.fc_audioset = nn.Linear(32, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = 
F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class MobileNetV1(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(MobileNetV1, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + def conv_bn(inp, oup, stride): + _layers = [ + nn.Conv2d(inp, oup, 3, 1, 1, bias=False), + nn.AvgPool2d(stride), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ] + _layers = nn.Sequential(*_layers) + init_layer(_layers[0]) + init_bn(_layers[2]) + return _layers + + def conv_dw(inp, oup, stride): + _layers = [ + nn.Conv2d(inp, inp, 3, 1, 1, groups=inp, bias=False), + nn.AvgPool2d(stride), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ] + _layers = nn.Sequential(*_layers) + init_layer(_layers[0]) + init_bn(_layers[2]) + init_layer(_layers[4]) + init_bn(_layers[5]) + return _layers + + self.features = nn.Sequential( + conv_bn(1, 32, 2), + conv_dw(32, 64, 1), + conv_dw(64, 128, 2), + conv_dw(128, 128, 1), + conv_dw(128, 256, 2), + conv_dw(256, 256, 1), + conv_dw(256, 512, 2), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 1024, 2), + conv_dw(1024, 1024, 1), + ) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.features(x) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + 
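# Note (added comment): the depthwise 3x3 convolution below always uses stride 1; temporal/frequency striding is applied by the following nn.AvgPool2d(stride), and the residual connection is only used when stride == 1 and inp == oup. +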
hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + _layers = [ + nn.Conv2d( + hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False + ), + nn.AvgPool2d(stride), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ] + _layers = nn.Sequential(*_layers) + init_layer(_layers[0]) + init_bn(_layers[2]) + init_layer(_layers[4]) + init_bn(_layers[5]) + self.conv = _layers + else: + _layers = [ + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + nn.Conv2d( + hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False + ), + nn.AvgPool2d(stride), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ] + _layers = nn.Sequential(*_layers) + init_layer(_layers[0]) + init_bn(_layers[1]) + init_layer(_layers[3]) + init_bn(_layers[5]) + init_layer(_layers[7]) + init_bn(_layers[8]) + self.conv = _layers + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(MobileNetV2, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + width_mult = 1.0 + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 2], + [6, 160, 3, 1], + [6, 320, 1, 1], + ] + + def conv_bn(inp, oup, stride): + _layers = [ + nn.Conv2d(inp, oup, 3, 1, 1, bias=False), + nn.AvgPool2d(stride), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True), + ] + _layers = nn.Sequential(*_layers) + init_layer(_layers[0]) + init_bn(_layers[2]) + return _layers + + def conv_1x1_bn(inp, oup): + _layers = nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True), + ) + init_layer(_layers[0]) + init_bn(_layers[1]) + return _layers + + # building first layer + input_channel = int(input_channel * width_mult) + self.last_channel = ( + int(last_channel * width_mult) if width_mult > 1.0 else last_channel + ) + features_lst: list[nn.Module] = [conv_bn(1, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + features_lst.append( + block(input_channel, output_channel, s, expand_ratio=t) + ) + else: + features_lst.append( + block(input_channel, output_channel, 1, expand_ratio=t) + ) + input_channel = output_channel + # building last several layers + 
features_lst.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*features_lst) + + self.fc1 = nn.Linear(1280, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.features(x) + + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + # x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class LeeNetConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(LeeNetConvBlock, self).__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=False, + ) + + self.bn1 = nn.BatchNorm1d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, x, pool_size=1): + x = F.relu_(self.bn1(self.conv1(x))) + if pool_size != 1: + x = F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) + return x + + +class LeeNet11(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(LeeNet11, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.conv_block1 = LeeNetConvBlock(1, 64, 3, 3) + self.conv_block2 = LeeNetConvBlock(64, 64, 3, 1) + self.conv_block3 = LeeNetConvBlock(64, 64, 3, 1) + self.conv_block4 = LeeNetConvBlock(64, 128, 3, 1) + self.conv_block5 = LeeNetConvBlock(128, 128, 3, 1) + self.conv_block6 = LeeNetConvBlock(128, 128, 3, 1) + self.conv_block7 = LeeNetConvBlock(128, 128, 3, 1) + self.conv_block8 = LeeNetConvBlock(128, 128, 3, 1) + self.conv_block9 = LeeNetConvBlock(128, 256, 3, 1) + + self.fc1 = nn.Linear(256, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input[:, None, :] + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x) + x = self.conv_block2(x, pool_size=3) + x = self.conv_block3(x, pool_size=3) + x = self.conv_block4(x, pool_size=3) + x = self.conv_block5(x, pool_size=3) + x = self.conv_block6(x, pool_size=3) + x = self.conv_block7(x, pool_size=3) + x = self.conv_block8(x, pool_size=3) + x = self.conv_block9(x, pool_size=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, 
training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class LeeNetConvBlock2(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(LeeNetConvBlock2, self).__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=False, + ) + + self.conv2 = nn.Conv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=False, + ) + + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, x, pool_size=1): + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_size != 1: + x = F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) + return x + + +class LeeNet24(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(LeeNet24, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.conv_block1 = LeeNetConvBlock2(1, 64, 3, 3) + self.conv_block2 = LeeNetConvBlock2(64, 96, 3, 1) + self.conv_block3 = LeeNetConvBlock2(96, 128, 3, 1) + self.conv_block4 = LeeNetConvBlock2(128, 128, 3, 1) + self.conv_block5 = LeeNetConvBlock2(128, 256, 3, 1) + self.conv_block6 = LeeNetConvBlock2(256, 256, 3, 1) + self.conv_block7 = LeeNetConvBlock2(256, 512, 3, 1) + self.conv_block8 = LeeNetConvBlock2(512, 512, 3, 1) + self.conv_block9 = LeeNetConvBlock2(512, 1024, 3, 1) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input[:, None, :] + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block2(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block3(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block4(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block5(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block6(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block7(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block8(x, pool_size=3) + x = F.dropout(x, p=0.1, training=self.training) + x = self.conv_block9(x, pool_size=1) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class DaiNetResBlock(nn.Module): + def 
__init__(self, in_channels, out_channels, kernel_size): + super(DaiNetResBlock, self).__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=False, + ) + + self.conv2 = nn.Conv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=False, + ) + + self.conv3 = nn.Conv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=False, + ) + + self.conv4 = nn.Conv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=False, + ) + + self.downsample = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + self.bn3 = nn.BatchNorm1d(out_channels) + self.bn4 = nn.BatchNorm1d(out_channels) + self.bn_downsample = nn.BatchNorm1d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_layer(self.conv3) + init_layer(self.conv4) + init_layer(self.downsample) + init_bn(self.bn1) + init_bn(self.bn2) + init_bn(self.bn3) + init_bn(self.bn4) + nn.init.constant_(self.bn4.weight, 0) + init_bn(self.bn_downsample) + + def forward(self, input, pool_size=1): + x = F.relu_(self.bn1(self.conv1(input))) + x = F.relu_(self.bn2(self.conv2(x))) + x = F.relu_(self.bn3(self.conv3(x))) + x = self.bn4(self.conv4(x)) + if input.shape == x.shape: + x = F.relu_(x + input) + else: + x = F.relu(x + self.bn_downsample(self.downsample(input))) + + if pool_size != 1: + x = F.max_pool1d(x, kernel_size=pool_size, padding=pool_size // 2) + return x + + +class DaiNet19(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(DaiNet19, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=80, + stride=4, + padding=0, + bias=False, + ) + self.bn0 = nn.BatchNorm1d(64) + self.conv_block1 = DaiNetResBlock(64, 64, 3) + self.conv_block2 = DaiNetResBlock(64, 128, 3) + self.conv_block3 = DaiNetResBlock(128, 256, 3) + self.conv_block4 = DaiNetResBlock(256, 512, 3) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input[:, None, :] + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.bn0(self.conv0(x)) + x = self.conv_block1(x) + x = F.max_pool1d(x, kernel_size=4) + x = self.conv_block2(x) + x = F.max_pool1d(x, kernel_size=4) + x = self.conv_block3(x) + x = F.max_pool1d(x, kernel_size=4) + x = self.conv_block4(x) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": 
clipwise_output, "embedding": embedding} + + return output_dict + + +def _resnet_conv3x1_wav1d(in_planes, out_planes, dilation): + # 3x3 convolution with padding + return nn.Conv1d( + in_planes, + out_planes, + kernel_size=3, + stride=1, + padding=dilation, + groups=1, + bias=False, + dilation=dilation, + ) + + +def _resnet_conv1x1_wav1d(in_planes, out_planes): + # 1x1 convolution + return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) + + +class _ResnetBasicBlockWav1d(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(_ResnetBasicBlockWav1d, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm1d + if groups != 1 or base_width != 64: + raise ValueError( + "_ResnetBasicBlock only supports groups=1 and base_width=64" + ) + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + + self.stride = stride + + self.conv1 = _resnet_conv3x1_wav1d(inplanes, planes, dilation=1) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = _resnet_conv3x1_wav1d(planes, planes, dilation=2) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + self.init_weights() + + def init_weights(self): + init_layer(self.conv1) + init_bn(self.bn1) + init_layer(self.conv2) + init_bn(self.bn2) + nn.init.constant_(self.bn2.weight, 0) + + def forward(self, x): + identity = x + + if self.stride != 1: + out = F.max_pool1d(x, kernel_size=self.stride) + else: + out = x + + out = self.conv1(out) + out = self.bn1(out) + out = self.relu(out) + out = F.dropout(out, p=0.1, training=self.training) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(identity) + + out += identity + out = self.relu(out) + + return out + + +class _ResNetWav1d(nn.Module): + def __init__( + self, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(_ResNetWav1d, self).__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm1d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + self.layer1 = self._make_layer(block, 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, 128, layers[1], stride=4) + self.layer3 = self._make_layer(block, 256, layers[2], stride=4) + self.layer4 = self._make_layer(block, 512, layers[3], stride=4) + self.layer5 = self._make_layer(block, 1024, layers[4], stride=4) + self.layer6 = self._make_layer(block, 1024, layers[5], stride=4) + self.layer7 = self._make_layer(block, 2048, layers[6], stride=4) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or 
self.inplanes != planes * block.expansion: + if stride == 1: + downsample = nn.Sequential( + _resnet_conv1x1_wav1d(self.inplanes, planes * block.expansion), + norm_layer(planes * block.expansion), + ) + init_layer(downsample[0]) + init_bn(downsample[1]) + else: + downsample = nn.Sequential( + nn.AvgPool1d(kernel_size=stride), + _resnet_conv1x1_wav1d(self.inplanes, planes * block.expansion), + norm_layer(planes * block.expansion), + ) + init_layer(downsample[1]) + init_bn(downsample[2]) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = self.layer7(x) + + return x + + +class Res1dNet31(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Res1dNet31, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=11, + stride=5, + padding=5, + bias=False, + ) + self.bn0 = nn.BatchNorm1d(64) + + self.resnet = _ResNetWav1d(_ResnetBasicBlockWav1d, [2, 2, 2, 2, 2, 2, 2]) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input[:, None, :] + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.bn0(self.conv0(x)) + x = self.resnet(x) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Res1dNet51(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Res1dNet51, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=11, + stride=5, + padding=5, + bias=False, + ) + self.bn0 = nn.BatchNorm1d(64) + + self.resnet = _ResNetWav1d(_ResnetBasicBlockWav1d, [2, 3, 4, 6, 4, 3, 2]) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input[:, None, :] + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.bn0(self.conv0(x)) 
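+ # After the stride-5 conv front-end, x is roughly (batch_size, 64, data_length / 5); _ResNetWav1d then downsamples the time axis by a factor of 4 in each of layers 2-7 before the global max/mean pooling below.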
+ x = self.resnet(x) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class ConvPreWavBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvPreWavBlock, self).__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + self.conv2 = nn.Conv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + dilation=2, + padding=2, + bias=False, + ) + + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input, pool_size): + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + x = F.max_pool1d(x, kernel_size=pool_size) + + return x + + +class Wavegram_Cnn14(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Wavegram_Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.pre_conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=11, + stride=5, + padding=5, + bias=False, + ) + self.pre_bn0 = nn.BatchNorm1d(64) + self.pre_block1 = ConvPreWavBlock(64, 64) + self.pre_block2 = ConvPreWavBlock(64, 128) + self.pre_block3 = ConvPreWavBlock(128, 128) + self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.pre_conv0) + init_bn(self.pre_bn0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + # Wavegram + a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) + a1 = self.pre_block1(a1, pool_size=4) + a1 = self.pre_block2(a1, pool_size=4) + a1 = self.pre_block3(a1, pool_size=4) + a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3) + a1 = self.pre_block4(a1, pool_size=(2, 1)) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + a1 = do_mixup(a1, mixup_lambda) + + x = a1 + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x 
= self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Wavegram_Logmel_Cnn14(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Wavegram_Logmel_Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.pre_conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=11, + stride=5, + padding=5, + bias=False, + ) + self.pre_bn0 = nn.BatchNorm1d(64) + self.pre_block1 = ConvPreWavBlock(64, 64) + self.pre_block2 = ConvPreWavBlock(64, 128) + self.pre_block3 = ConvPreWavBlock(128, 128) + self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=128, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.pre_conv0) + init_bn(self.pre_bn0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + # Wavegram + a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) + a1 = self.pre_block1(a1, pool_size=4) + a1 = self.pre_block2(a1, pool_size=4) + a1 = self.pre_block3(a1, pool_size=4) + a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3) + a1 = self.pre_block4(a1, pool_size=(2, 1)) + + # Log mel spectrogram + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + a1 = 
do_mixup(a1, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + # Concatenate Wavegram and Log mel spectrogram along the channel dimension + x = torch.cat((x, a1), dim=1) + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Wavegram_Logmel128_Cnn14(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Wavegram_Logmel128_Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.pre_conv0 = nn.Conv1d( + in_channels=1, + out_channels=64, + kernel_size=11, + stride=5, + padding=5, + bias=False, + ) + self.pre_bn0 = nn.BatchNorm1d(64) + self.pre_block1 = ConvPreWavBlock(64, 64) + self.pre_block2 = ConvPreWavBlock(64, 128) + self.pre_block3 = ConvPreWavBlock(128, 256) + self.pre_block4 = ConvBlock(in_channels=4, out_channels=64) + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=16, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(128) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=128, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_layer(self.pre_conv0) + init_bn(self.pre_bn0) + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + # Wavegram + a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :]))) + a1 = self.pre_block1(a1, pool_size=4) + a1 = self.pre_block2(a1, pool_size=4) + a1 = self.pre_block3(a1, pool_size=4) + a1 = a1.reshape((a1.shape[0], -1, 64, 
a1.shape[-1])).transpose(2, 3) + a1 = self.pre_block4(a1, pool_size=(2, 1)) + + # Log mel spectrogram + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + a1 = do_mixup(a1, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + # Concatenate Wavegram and Log mel spectrogram along the channel dimension + x = torch.cat((x, a1), dim=1) + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_16k(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_16k, self).__init__() + + assert sample_rate == 16000 + assert window_size == 512 + assert hop_size == 160 + assert mel_bins == 64 + assert fmin == 50 + assert fmax == 8000 + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, 
time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_8k(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_8k, self).__init__() + + assert sample_rate == 8000 + assert window_size == 256 + assert hop_size == 80 + assert mel_bins == 64 + assert fmin == 50 + assert fmax == 4000 + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, 
mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_mixup_time_domain(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_mixup_time_domain, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = input + + # Mixup in time domain + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = 
F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_mel32(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_mel32, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=4, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(32) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + 
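# fc1 keeps a 2048-d penultimate representation here; its dropout output is returned as "embedding", while fc_audioset followed by a sigmoid yields the clipwise predictions. +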
embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +class Cnn14_mel128(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_mel128, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=16, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(128) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {"clipwise_output": clipwise_output, "embedding": embedding} + + return output_dict + + +############ +class Cnn14_DecisionLevelMax(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_DecisionLevelMax, self).__init__() + + window = "hann" + center 
= True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + frames_num = x.shape[2] + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = x.transpose(1, 2) + x = F.relu_(self.fc1(x)) + x = F.dropout(x, p=0.5, training=self.training) + segmentwise_output = torch.sigmoid(self.fc_audioset(x)) + (clipwise_output, _) = torch.max(segmentwise_output, dim=1) + + # Get framewise output + framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) + framewise_output = pad_framewise_output(framewise_output, frames_num) + + output_dict = { + "framewise_output": framewise_output, + "clipwise_output": clipwise_output, + } + + return output_dict + + +class Cnn14_DecisionLevelAvg(nn.Module): + def __init__( + self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_DecisionLevelAvg, self).__init__() + + window = "hann" + center = True + 
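The DecisionLevel variants return both a clip-level and a frame-level prediction. The sketch below illustrates the expected output shapes of `Cnn14_DecisionLevelMax`; the import path and the AudioSet-style hyperparameters are assumptions for illustration, not values taken from this patch.

```python
# Hypothetical usage sketch (module path and hyperparameters are assumed).
import torch

from conette.nn.pann_utils.models import Cnn14_DecisionLevelMax  # path assumed

model = Cnn14_DecisionLevelMax(
    sample_rate=32000, window_size=1024, hop_size=320,
    mel_bins=64, fmin=50, fmax=14000, classes_num=527,
)
model.eval()

waveform = torch.rand(2, 32000 * 10)  # (batch_size, data_length)
with torch.no_grad():
    out = model(waveform)

print(out["clipwise_output"].shape)   # (2, 527): max over segment scores
print(out["framewise_output"].shape)  # (2, frames_num, 527): per-frame scores
```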
pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + frames_num = x.shape[2] + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = x.transpose(1, 2) + x = F.relu_(self.fc1(x)) + x = F.dropout(x, p=0.5, training=self.training) + segmentwise_output = torch.sigmoid(self.fc_audioset(x)) + clipwise_output = torch.mean(segmentwise_output, dim=1) + + # Get framewise output + framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) + framewise_output = pad_framewise_output(framewise_output, frames_num) + + # Get framewise output + framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) + framewise_output = pad_framewise_output(framewise_output, frames_num) + + output_dict = { + "framewise_output": framewise_output, + "clipwise_output": clipwise_output, + } + + return output_dict + + +class Cnn14_DecisionLevelAtt(nn.Module): + def __init__( + self, 
sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num + ): + super(Cnn14_DecisionLevelAtt, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.att_block = AttBlock(2048, classes_num, activation="sigmoid") + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + frames_num = x.shape[2] + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = x.transpose(1, 2) + x = F.relu_(self.fc1(x)) + x = x.transpose(1, 2) + x = F.dropout(x, p=0.5, training=self.training) + (clipwise_output, _, segmentwise_output) = self.att_block(x) + segmentwise_output = segmentwise_output.transpose(1, 2) + + # Get framewise output + framewise_output = interpolate(segmentwise_output, self.interpolate_ratio) + framewise_output = pad_framewise_output(framewise_output, frames_num) + + output_dict = { + "framewise_output": framewise_output, + "clipwise_output": clipwise_output, + } + + return output_dict diff --git a/src/conette/nn/pann_utils/pytorch_utils.py 
b/src/conette/nn/pann_utils/pytorch_utils.py new file mode 100644 index 000000000..5a7bff523 --- /dev/null +++ b/src/conette/nn/pann_utils/pytorch_utils.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import time + +import numpy as np +import torch + +from torch import nn + + +def move_data_to_device(x, device): + if "float" in str(x.dtype): + x = torch.Tensor(x) + elif "int" in str(x.dtype): + x = torch.LongTensor(x) + else: + return x + + return x.to(device) + + +def do_mixup(x, mixup_lambda): + """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes + (1, 3, 5, ...). + + Args: + x: (batch_size * 2, ...) + mixup_lambda: (batch_size * 2,) + + Returns: + out: (batch_size, ...) + """ + out = ( + x[0::2].transpose(0, -1) * mixup_lambda[0::2] + + x[1::2].transpose(0, -1) * mixup_lambda[1::2] + ).transpose(0, -1) + return out + + +def append_to_dict(dict, key, value): + if key in dict.keys(): + dict[key].append(value) + else: + dict[key] = [value] + + +def forward(model, generator, return_input=False, return_target=False): + """Forward data to a model. + + Args: + model: object + generator: object + return_input: bool + return_target: bool + + Returns: + audio_name: (audios_num,) + clipwise_output: (audios_num, classes_num) + (ifexist) segmentwise_output: (audios_num, segments_num, classes_num) + (ifexist) framewise_output: (audios_num, frames_num, classes_num) + (optional) return_input: (audios_num, segment_samples) + (optional) return_target: (audios_num, classes_num) + """ + output_dict = {} + device = next(model.parameters()).device + time1 = time.time() + + # Forward data to a model in mini-batches + for n, batch_data_dict in enumerate(generator): + print(n) + batch_waveform = move_data_to_device(batch_data_dict["waveform"], device) + + with torch.no_grad(): + model.eval() + batch_output = model(batch_waveform) + + append_to_dict(output_dict, "audio_name", batch_data_dict["audio_name"]) + + append_to_dict( + output_dict, + "clipwise_output", + batch_output["clipwise_output"].data.cpu().numpy(), + ) + + if "segmentwise_output" in batch_output.keys(): + append_to_dict( + output_dict, + "segmentwise_output", + batch_output["segmentwise_output"].data.cpu().numpy(), + ) + + if "framewise_output" in batch_output.keys(): + append_to_dict( + output_dict, + "framewise_output", + batch_output["framewise_output"].data.cpu().numpy(), + ) + + if return_input: + append_to_dict(output_dict, "waveform", batch_data_dict["waveform"]) + + if return_target: + if "target" in batch_data_dict.keys(): + append_to_dict(output_dict, "target", batch_data_dict["target"]) + + if n % 10 == 0: + print( + " --- Inference time: {:.3f} s / 10 iterations ---".format( + time.time() - time1 + ) + ) + time1 = time.time() + + for key in output_dict.keys(): + output_dict[key] = np.concatenate(output_dict[key], axis=0) + + return output_dict + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. 
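Two of these helpers are easy to misread from their shapes alone: `do_mixup` halves the batch by blending consecutive pairs, and `interpolate` repeats each time step `ratio` times. A minimal self-contained sketch (the import path is the module added by this patch; the tensors are illustrative):

```python
import torch

from conette.nn.pann_utils.pytorch_utils import do_mixup, interpolate

# do_mixup: pairs (0, 1), (2, 3), ... are blended, so 6 inputs give 3 outputs.
x = torch.rand(6, 32000)
lam = torch.tensor([0.7, 0.3, 0.5, 0.5, 0.9, 0.1])  # per-sample weights
mixed = do_mixup(x, lam)
assert mixed.shape == (3, 32000)

# interpolate: each of the 31 segment predictions is repeated 32 times.
segmentwise = torch.rand(2, 31, 527)
upsampled = interpolate(segmentwise, 32)
assert upsampled.shape == (2, 31 * 32, 527)
```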
The pad value + is the same as the value of the last frame. + + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + return output + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def count_flops(model, audio_length): + """Count flops. Code modified from others' implementation.""" + multiply_adds = True + list_conv2d = [] + + def conv2d_hook(self, input, output): + batch_size, input_channels, input_height, input_width = input[0].size() + output_channels, output_height, output_width = output[0].size() + + kernel_ops = ( + self.kernel_size[0] + * self.kernel_size[1] + * (self.in_channels / self.groups) + * (2 if multiply_adds else 1) + ) + bias_ops = 1 if self.bias is not None else 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_height * output_width + + list_conv2d.append(flops) + + list_conv1d = [] + + def conv1d_hook(self, input, output): + batch_size, input_channels, input_length = input[0].size() + output_channels, output_length = output[0].size() + + kernel_ops = ( + self.kernel_size[0] + * (self.in_channels / self.groups) + * (2 if multiply_adds else 1) + ) + bias_ops = 1 if self.bias is not None else 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_length + + list_conv1d.append(flops) + + list_linear = [] + + def linear_hook(self, input, output): + batch_size = input[0].size(0) if input[0].dim() == 2 else 1 + + weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) + bias_ops = self.bias.nelement() + + flops = batch_size * (weight_ops + bias_ops) + list_linear.append(flops) + + list_bn = [] + + def bn_hook(self, input, output): + list_bn.append(input[0].nelement() * 2) + + list_relu = [] + + def relu_hook(self, input, output): + list_relu.append(input[0].nelement() * 2) + + list_pooling2d = [] + + def pooling2d_hook(self, input, output): + batch_size, input_channels, input_height, input_width = input[0].size() + output_channels, output_height, output_width = output[0].size() + + kernel_ops = self.kernel_size * self.kernel_size + bias_ops = 0 + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_height * output_width + + list_pooling2d.append(flops) + + list_pooling1d = [] + + def pooling1d_hook(self, input, output): + batch_size, input_channels, input_length = input[0].size() + output_channels, output_length = output[0].size() + + kernel_ops = self.kernel_size[0] + bias_ops = 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_length + + list_pooling2d.append(flops) + + def foo(net): + childrens = list(net.children()) + if not childrens: + if isinstance(net, nn.Conv2d): + net.register_forward_hook(conv2d_hook) + elif isinstance(net, nn.Conv1d): + net.register_forward_hook(conv1d_hook) + elif isinstance(net, nn.Linear): + net.register_forward_hook(linear_hook) + elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): + net.register_forward_hook(bn_hook) + elif isinstance(net, nn.ReLU): + net.register_forward_hook(relu_hook) + elif isinstance(net, nn.AvgPool2d) or isinstance(net, 
nn.MaxPool2d): + net.register_forward_hook(pooling2d_hook) + elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): + net.register_forward_hook(pooling1d_hook) + else: + print("Warning: flop of module {} is not counted!".format(net)) + return + for c in childrens: + foo(c) + + # Register hook + foo(model) + + device = device = next(model.parameters()).device + input = torch.rand(1, audio_length).to(device) + + out = model(input) + + total_flops = ( + sum(list_conv2d) + + sum(list_conv1d) + + sum(list_linear) + + sum(list_bn) + + sum(list_relu) + + sum(list_pooling2d) + + sum(list_pooling1d) + ) + + return total_flops diff --git a/src/conette/pl_modules/base.py b/src/conette/pl_modules/base.py index 7efabc9ac..c2c938fc8 100644 --- a/src/conette/pl_modules/base.py +++ b/src/conette/pl_modules/base.py @@ -255,19 +255,15 @@ def attach_example(self) -> None: self.example_input_array = self.get_example() def encode_text(self, *args, **kwargs) -> Any: - assert isinstance(self.tokenizer, AACTokenizer) return self.tokenizer.encode_rec(*args, **kwargs) def tokenize_text(self, *args, **kwargs) -> Any: - assert isinstance(self.tokenizer, AACTokenizer) return self.tokenizer.tokenize_rec(*args, **kwargs) def decode_text(self, *args, **kwargs) -> Any: - assert isinstance(self.tokenizer, AACTokenizer) return self.tokenizer.decode_rec(*args, **kwargs) def detokenize_text(self, *args, **kwargs) -> Any: - assert isinstance(self.tokenizer, AACTokenizer) return self.tokenizer.detokenize_rec(*args, **kwargs) def csum_module(self, only_trainable: bool = False) -> int: diff --git a/src/conette/pl_modules/baseline.py b/src/conette/pl_modules/baseline.py new file mode 100644 index 000000000..84da1993e --- /dev/null +++ b/src/conette/pl_modules/baseline.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Any, Optional + +import torch + +from torch import nn, Tensor + +from conette.nn.decoders.aac_tfmer import AACTransformerDecoder +from conette.nn.decoding.beam import generate +from conette.nn.decoding.forcing import teacher_forcing +from conette.nn.decoding.greedy import greedy_search +from conette.nn.encoders.ident import FrameIdentEncoder +from conette.nn.functional.indexes import randperm_diff +from conette.nn.functional.mask import ( + lengths_to_pad_mask, + tensor_to_pad_mask, +) +from conette.nn.loss.ce_mean import CrossEntropyLossMean +from conette.pl_modules.base import AACLightningModule +from conette.pl_modules.common import ( + build_proj_lin, + get_forbid_rep_mask, + TrainBatch, + ValBatch, + TestBatch, +) +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.transforms.mixup import sample_lambda + + +pylog = logging.getLogger(__name__) + + +class BaselinePLM(AACLightningModule): + def __init__( + self, + # Model params + label_smoothing: float = 0.1, + gen_val_cands: str = "generate", + mixup_alpha: float = 0.4, + # Encoder params + proj_name: str = "lin768", + # Decoder params + min_pred_size: int = 3, + max_pred_size: Optional[int] = None, + beam_size: int = 10, + nhead: int = 8, + d_model: int = 256, + num_decoder_layers: int = 6, + decoder_dropout_p: float = 0.2, + dim_feedforward: int = 2048, + acti_name: str = "gelu", + # Optimizer params + optim_name: str = "AdamW", + lr: float = 5e-4, + weight_decay: float = 2.0, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + use_custom_wd: bool = True, + # Scheduler params + sched_name: str = "cos_decay", + sched_n_steps: Optional[int] = None, + 
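The FLOP counter above works by registering forward hooks on leaf modules and summing per-layer estimates while a dummy waveform of `audio_length` samples is pushed through the model. A toy check (the import path is the file added by this patch; the model below is illustrative only, not a PANN network):

```python
from torch import nn

from conette.nn.pann_utils.pytorch_utils import count_flops, count_parameters

# Any module accepting a (1, audio_length) tensor can be profiled this way.
toy = nn.Sequential(nn.Linear(32000, 128), nn.ReLU(), nn.Linear(128, 10))

print(count_parameters(toy))                 # number of trainable parameters
print(count_flops(toy, audio_length=32000))  # multiply-adds counted as 2 ops
```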
sched_interval: str = "epoch", + sched_freq: int = 1, + # Other params + train_tokenizer: Optional[AACTokenizer] = None, + verbose: int = 0, + ) -> None: + super().__init__(train_tokenizer) + + self.train_criterion: nn.Module = nn.Identity() + self.encoder: nn.Module = nn.Identity() + self.decoder: AACTransformerDecoder = None # type: ignore + self.projection: nn.Module = nn.Identity() + + self.save_hyperparameters(ignore=("train_tokenizer",)) + + if self.tokenizer.is_fit(): + self.setup() + + # --- Setup methods + def build_model(self) -> None: + if self.is_built(): + raise RuntimeError("Cannot build model twice.") + + if not self.tokenizer.is_fit(): + raise RuntimeError( + f"AACTokenizer is not fit for {self.__class__.__name__}." + ) + + tok_max_sent_size = self.tokenizer.get_max_sentence_size() + if self.hp.max_pred_size is None: + self.hparams.max_pred_size = tok_max_sent_size # type: ignore + if self.hp.verbose >= 1: + pylog.info(f"Auto-detect value {self.hp.max_pred_size=}.") + else: + if self.hp.verbose >= 1: + pylog.info( + f"Set {self.hp.max_pred_size=}. (with tokenizer max={tok_max_sent_size})" + ) + + self.train_criterion = nn.CrossEntropyLoss( + ignore_index=self.pad_id, + label_smoothing=self.hp.label_smoothing, + ) + self.val_criterion = CrossEntropyLossMean(ignore_index=self.pad_id, dim=1) + self.encoder = FrameIdentEncoder() + + if self.hp.proj_name == "lin2048": + self.projection = build_proj_lin(2048, self.hp.d_model, False) + elif self.hp.proj_name == "lin768": + self.projection = build_proj_lin(768, self.hp.d_model, False) + else: + raise ValueError(f"Invalid argument {self.hp.proj_name=}.") + + self.decoder = AACTransformerDecoder( + vocab_size=self.tokenizer.get_vocab_size(), + bos_id=self.bos_id, + eos_id=self.eos_id, + pad_id=self.pad_id, + acti_name=self.hp.acti_name, + d_model=self.hp.d_model, + dim_feedforward=self.hp.dim_feedforward, + dropout=self.hp.decoder_dropout_p, + nhead=self.hp.nhead, + num_decoder_layers=self.hp.num_decoder_layers, + ) + + forbid_rep_mask = get_forbid_rep_mask( + "content_words", + self.tokenizer, + self.device, + self.hp.verbose, + ) + + self.forbid_rep_mask: Optional[Tensor] + self.register_buffer("forbid_rep_mask", forbid_rep_mask) + + def is_built(self) -> bool: + return self.decoder is not None + + # --- Train, val and test methods + def training_step(self, batch: TrainBatch, *args, **kwargs) -> Tensor: + audio = batch["audio"] + audio_shape = batch["audio_shape"] + captions = batch["captions"] + + bsize = captions.shape[0] + # if precomputed, audio: (bsize, n_channels=1, time_steps=31, emb_size=2048) + # captions : (bsize, max_cap_size) + + # Apply mixup on audio and input token embs + indexes = randperm_diff(bsize, device=self.device) + audio, audio_shape, lbd = self.mix_audio(audio, audio_shape, indexes) + + caps_in = captions[:, :-1] + caps_out = captions[:, 1:] + del captions + + caps_in_pad_mask = tensor_to_pad_mask(caps_in, pad_value=self.pad_id) + + caps_in = self.decoder.emb_layer(caps_in) + caps_in = caps_in * lbd + caps_in[indexes] * (1.0 - lbd) + + # Forward + encoder_outs = self.encode_audio(audio, audio_shape) + logits = self.decode_audio( + encoder_outs, + "forcing", + caps_in=caps_in, + caps_in_pad_mask=caps_in_pad_mask, + ) # note: use mixed prev tokens + + loss = self.train_criterion(logits, caps_out) # note: use unmixed target + + with torch.no_grad(): + scores = { + "loss": loss, + } + prefix = "train" + scores = {f"{prefix}/{k}": v for k, v in scores.items()} + self.log_dict( + scores, + batch_size=bsize, + ) + 
+ return loss + + def validation_step(self, batch: ValBatch, *args, **kwargs) -> Any: + audio = batch["audio"] + audio_shape = batch["audio_shape"] + mult_captions = batch["mult_captions"] + bsize, n_caps_per_audio, _ = mult_captions.shape + + losses = torch.empty( + size=(bsize, n_caps_per_audio), + dtype=audio.dtype, + device=audio.device, + ) + + encoder_outs = self.encode_audio(audio, audio_shape) + + for i in range(n_caps_per_audio): + caps_in = mult_captions[:, i, :-1] + caps_out = mult_captions[:, i, 1:] + + # logits : (bsize, vocab_size, capt_len) + logits_i = self.decode_audio(encoder_outs, "forcing", caps_in=caps_in) + losses_i = self.val_criterion(logits_i, caps_out) + losses[:, i] = losses_i + + loss = losses.mean() + + if self.hp.gen_val_cands in ("none", None): + output = None + + elif self.hp.gen_val_cands in ("greedy", "generate"): + # Compute beam search results + outs = self.decode_audio(encoder_outs, self.hp.gen_val_cands) + if self.hp.gen_val_cands == "greedy": + preds = outs.argmax(dim=1) + else: + preds = outs[0] + + cands = self.decode_text(preds) + mrefs = batch["mult_references"] + output = { + f"cands_{self.hp.gen_val_cands}": cands, + "mrefs": mrefs, + } + else: + raise ValueError(f"Invalid argument {self.hp.gen_val_cands=}.") + + bar_scores = {"loss": loss} + non_bar_scores = {} + + prefix = "val" + bar_scores = {f"{prefix}/{k}": v for k, v in bar_scores.items()} + non_bar_scores = {f"{prefix}/{k}": v for k, v in non_bar_scores.items()} + + log_kwargs: dict[str, Any] = dict(batch_size=bsize) + self.log_dict(bar_scores, prog_bar=True, **log_kwargs) + self.log_dict(non_bar_scores, prog_bar=False, **log_kwargs) + + return output + + def test_step(self, batch: TestBatch, *args, **kwargs) -> dict[str, Any]: + audio = batch["audio"] + audio_shape = batch["audio_shape"] + mult_captions = batch["mult_captions"] + + bsize, n_caps_per_audio, _ = mult_captions.shape + encoder_outs = self.encode_audio(audio, audio_shape) + + # Compute test loss + losses = torch.empty( + size=(bsize, n_caps_per_audio), + dtype=audio.dtype, + device=audio.device, + ) + + for i in range(n_caps_per_audio): + caps_in = mult_captions[:, i, :-1] + caps_out = mult_captions[:, i, 1:] + logits_i = self.decode_audio(encoder_outs, "forcing", caps_in=caps_in) + losses_i = self.val_criterion(logits_i, caps_out) + losses[:, i] = losses_i + + loss = losses.mean() + + dataname = batch["dataset"][0] + subset = batch["subset"][0] + scores = { + f"test/{dataname}_{subset}.loss": loss, + } + self.log_dict(scores, batch_size=bsize) + + # Compute beam search results + preds, lprobs, mult_preds, mult_lprobs = self.decode_audio( + encoder_outs, "generate" + ) + outs = { + "losses": losses, + "preds": preds, + "lprobs": lprobs, + "mpreds": mult_preds, + "mlprobs": mult_lprobs, + } + + # Decode beam search results + keys = [key for key in outs.keys() if "preds" in key] + for key in keys: + cands_key = key.replace("preds", "cands") + + preds = outs[key] + cands = self.tokenizer.decode_rec(preds) + + outs[cands_key] = cands + + if "mult_references" in batch: + outs["mrefs"] = batch["mult_references"] + return outs + + def forward( + self, + batch: dict[str, Any], + decode_method: str = "generate", + **kwargs, + ) -> dict[str, Tensor]: + audio: Tensor = batch["audio"] + audio_shape: Tensor = batch["audio_shape"] + encoder_outs = self.encode_audio(audio, audio_shape) + if decode_method == "forcing" and "captions" in batch: + kwargs["caps_in"] = batch["captions"][:, :-1] + outs = self.decode_audio(encoder_outs, 
decode_method, **kwargs) + + if decode_method == "generate": + preds, lprobs, _mult_preds, _mult_lprobs = outs + cands = self.decode_text(preds) + return {"cands": cands, "preds": preds, "lprobs": lprobs} + else: + return outs + + # --- Other methods + def decode_audio( + self, + encoder_outs: dict[str, Tensor], + decode_method: str, + **kwargs, + ) -> Any: + if decode_method == "forcing": + if "caps_in" not in kwargs.keys(): + raise ValueError( + f"Please provide a 'caps_in' keyword argument with {decode_method=}. (found {tuple(kwargs.keys())})" + ) + forcing_hp: dict[str, Any] = { + "pad_id": self.pad_id, + "bos_id": self.bos_id, + "eos_id": self.eos_id, + "vocab_size": self.tokenizer.get_vocab_size(), + } + kwargs = forcing_hp | kwargs + outs = teacher_forcing( + self.decoder, + **encoder_outs, + **kwargs, + ) + elif decode_method == "greedy": + greedy_hp = { + "pad_id": self.pad_id, + "bos_id": self.bos_id, + "eos_id": self.eos_id, + "vocab_size": self.tokenizer.get_vocab_size(), + "min_pred_size": self.hp.min_pred_size, + "max_pred_size": self.hp.max_pred_size, + "forbid_rep_mask": self.forbid_rep_mask, + } + kwargs = greedy_hp | kwargs + outs = greedy_search( + self.decoder, + **encoder_outs, + **kwargs, + ) + + elif decode_method == "generate": + generate_hp = { + "pad_id": self.pad_id, + "bos_id": self.bos_id, + "eos_id": self.eos_id, + "vocab_size": self.tokenizer.get_vocab_size(), + "min_pred_size": self.hp.min_pred_size, + "max_pred_size": self.hp.max_pred_size, + "forbid_rep_mask": self.forbid_rep_mask, + "beam_size": self.hp.beam_size, + } + kwargs = generate_hp | kwargs + outs = generate( + self.decoder, + **encoder_outs, + **kwargs, + ) + else: + DECODE_METHODS = ("forcing", "greedy", "generate") + raise ValueError( + f"Unknown argument {decode_method=}. 
(expected one of {DECODE_METHODS})" + ) + return outs + + def encode_audio(self, audio: Tensor, audio_shape: Tensor) -> dict[str, Tensor]: + encoder_outs = self.encoder(audio, audio_shape) + frame_embs = encoder_outs["frame_embs"] + frame_embs_lens = encoder_outs.pop("frame_embs_lens") + + frame_embs = self.projection(frame_embs) + # frame_embs shape: (bsize, emb_size, time_size) + + time_dim = -1 + frame_embs_max_len = max(frame_embs_lens.max(), frame_embs.shape[time_dim]) + frame_embs_pad_mask = lengths_to_pad_mask(frame_embs_lens, frame_embs_max_len) + + encoder_outs["frame_embs"] = frame_embs + encoder_outs["frame_embs_pad_mask"] = frame_embs_pad_mask + + return encoder_outs + + def mix_audio( + self, + audio: Tensor, + audio_shape: Tensor, + indexes: Optional[Tensor], + ) -> tuple[Tensor, Tensor, Tensor]: + if indexes is None: + return audio, audio_shape, torch.full((), 1.0) + + lbd = sample_lambda( + self.hp.mixup_alpha, + asymmetric=True, + size=(), + ) + mixed_audio = audio * lbd + audio[indexes] * (1.0 - lbd) + mixed_audio_shape = torch.max(audio_shape, audio_shape[indexes]) + return mixed_audio, mixed_audio_shape, lbd diff --git a/src/conette/predict.py b/src/conette/predict.py index 3acb03744..b683074ad 100644 --- a/src/conette/predict.py +++ b/src/conette/predict.py @@ -6,10 +6,18 @@ import os.path as osp from argparse import ArgumentParser, Namespace +from typing import Optional, Union + +import torch +import transformers +import yaml from lightning_fabric.utilities.seed import seed_everything +from omegaconf import OmegaConf, DictConfig +from conette.nn.functional.get import get_device from conette.huggingface.model import CoNeTTEConfig, CoNeTTEModel +from conette.pl_modules.conette import CoNeTTEPLM from conette.utils.cmdline import _str_to_opt_str, _str_to_opt_int, _setup_logging from conette.utils.csum import csum_module @@ -39,10 +47,16 @@ def get_predict_args() -> Namespace: ) parser.add_argument( "--model_name", - type=str, + type=_str_to_opt_str, help="Model name on huggingface.", default="Labbeti/conette", ) + parser.add_argument( + "--model_path", + type=_str_to_opt_str, + help="Path to trained model directory.", + default=None, + ) parser.add_argument( "--device", type=str, @@ -77,6 +91,81 @@ def get_predict_args() -> Namespace: return args +def _load_hf_model( + model_name: str, + token: Optional[str], + device: Union[str, torch.device, None], + verbose: int, +) -> CoNeTTEModel: + if verbose >= 1: + pylog.info(f"Initilizing '{model_name}' model...") + + # To support transformers < 4.35, which is required for aac-metrics dependancy + major, minor, _patch = map(int, transformers.__version__.split(".")) + if major < 4 or (major == 4 and minor < 35): + token_key = "use_auth_token" + else: + token_key = "token" + + common_args = { + "pretrained_model_name_or_path": model_name, + token_key: token, + } + config = CoNeTTEConfig.from_pretrained(**common_args) + hf_model: CoNeTTEModel = CoNeTTEModel.from_pretrained( # type: ignore + config=config, + device=device, + **common_args, + ) + if verbose >= 1: + pylog.info(f"Model '{model_name}' is initialized.") + return hf_model + + +def _load_model_from_path( + model_path: str, + device: Union[str, torch.device, None], + verbose: int, +) -> CoNeTTEModel: + if verbose >= 1: + pylog.info(f"Initilizing model from '{model_path}'...") + + cfg_fpath = osp.join(model_path, "hydra", "config.yaml") + ckpt_fpath = osp.join(model_path, "checkpoints", "best.ckpt") + + if not osp.isfile(cfg_fpath): + raise FileNotFoundError( + f"Cannot find 
config file in model_path directory. ({cfg_fpath} is not a file)" + ) + if not osp.isfile(ckpt_fpath): + raise FileNotFoundError( + f"Cannot find checkpoint file in model_path directory. ({ckpt_fpath} is not a file)" + ) + + with open(cfg_fpath, "r") as file: + raw_cfg = yaml.safe_load(file) + cfg: DictConfig = OmegaConf.create(raw_cfg) # type: ignore + + pl_cfg = cfg.get("pl", {}) + target = pl_cfg.pop("_target_", "unknown") + if CoNeTTEPLM.__name__ not in target: + raise NotImplementedError(f"Unsupported pretrained model type '{target}'.") + + model = CoNeTTEPLM(**pl_cfg) + + ckpt_data = torch.load(ckpt_fpath, map_location=model.device) + state_dict = ckpt_data["state_dict"] + model.load_state_dict(state_dict, strict=True) + + device = get_device(device) + config = CoNeTTEConfig(**pl_cfg) + hf_model = CoNeTTEModel(config, device=device, model_override=model) + + if verbose >= 1: + pylog.info(f"Model from '{model_path}' is initialized.") + return hf_model + + def main_predict() -> None: """Main entrypoint for CoNeTTE predict.""" args = get_predict_args() @@ -86,24 +175,19 @@ def main_predict() -> None: fpaths = list(args.audio) tasks = args.task - if args.verbose >= 1: - pylog.info(f"Initilizing '{args.model_name}' model...") + if args.model_path is not None: + hf_model = _load_model_from_path(args.model_path, args.device, args.verbose) + elif args.model_name is not None: + hf_model = _load_hf_model( + args.model_name, args.token, args.device, args.verbose + ) + else: + raise ValueError( + f"Invalid arguments {args.model_name=} and {args.model_path=}. (expected at one str value)" + ) - config = CoNeTTEConfig.from_pretrained( - args.model_name, - token=args.token, - ) - hf_model: CoNeTTEModel = CoNeTTEModel.from_pretrained( # type: ignore - args.model_name, - config=config, - device=args.device, - token=args.token, - ) hf_model.eval_and_detach() - if args.verbose >= 1: - pylog.info(f"Model '{args.model_name}' is initialized.") - if args.verbose >= 2: enc_csum = csum_module(hf_model.preprocessor.encoder, with_names=False) model_csum = csum_module(hf_model.model, with_names=False) diff --git a/src/conette/prepare.py b/src/conette/prepare.py new file mode 100644 index 000000000..45468e9de --- /dev/null +++ b/src/conette/prepare.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os + +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +os.environ["TRANSFORMERS_OFFLINE"] = "FALSE" +os.environ["HF_HUB_OFFLINE"] = "FALSE" + +import logging +import math +import os.path as osp +import random +import subprocess +import sys +import time + +from subprocess import CalledProcessError +from typing import Any + +import hydra +import nltk +import spacy +import torch +import torchaudio +import yaml + +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig +from torch import nn +from torch.hub import download_url_to_file +from torchaudio.backend.common import AudioMetaData + +from aac_datasets.datasets.audiocaps import AudioCaps, AudioCapsCard, _AUDIOCAPS_LINKS +from aac_datasets.datasets.clotho import Clotho, ClothoCard +from aac_datasets.datasets.macs import MACS, MACSCard +from aac_datasets.datasets.wavcaps import WavCaps +from aac_metrics.download import download_metrics as download_aac_metrics + +from conette.callbacks.stats_saver import save_to_dir +from conette.datamodules.common import get_hdf_fpaths +from conette.datasets.hdf import 
HDFDataset, pack_to_hdf +from conette.datasets.typing import AACDatasetLike +from conette.datasets.utils import ( + AACSubset, + AACSelectColumnsWrapper, + TransformWrapper, + load_audio_metadata, +) +from conette.nn.functional.misc import count_params +from conette.nn.cnext_ckpt_utils import CNEXT_PRETRAINED_URLS +from conette.nn.pann_utils.hub import PANN_PRETRAINED_URLS +from conette.transforms.utils import DictTransform +from conette.utils.collections import unzip +from conette.utils.csum import csum_any +from conette.utils.disk_cache import disk_cache +from conette.utils.hydra import setup_resolvers, get_subrun_path +from conette.train import setup_run, teardown_run + + +pylog = logging.getLogger(__name__) + +# Note: this function must be called globally +setup_resolvers() + + +def download_models(cfg: DictConfig) -> None: + if cfg.nltk: + # Download wordnet and omw-1.4 NLTK model for nltk METEOR metric + # Download punkt NLTK model for nltk tokenizer + # Download stopwords for constrained beam seach generation + for model_name in ( + "wordnet", + "omw-1.4", + "punkt", + "averaged_perceptron_tagger", + "stopwords", + ): + nltk.download(model_name) + + if cfg.spacy: + # Download spaCy model for AACTokenizer + SPACY_MODELS = ("en_core_web_sm", "fr_core_news_sm", "xx_ent_wiki_sm") + for model_name in SPACY_MODELS: + try: + _model = spacy.load(model_name) + pylog.info(f"Model '{model_name}' for spacy is already downloaded.") + except OSError: + command = [sys.executable, "-m", "spacy", "download", model_name] + try: + subprocess.check_call( + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + pylog.info(f"Model '{model_name}' for spacy has been downloaded.") + except (CalledProcessError, PermissionError) as err: # type: ignore + pylog.error( + f"Cannot download spaCy model '{model_name}' for tokenizer. (command '{command}' with error={err})" + ) + + if str(cfg.pann).lower() != "none": + ckpt_dir = osp.join(torch.hub.get_dir(), "checkpoints") + os.makedirs(ckpt_dir, exist_ok=True) + + def can_download(name: str, pattern: Any) -> bool: + if pattern == "all": + return True + elif isinstance(pattern, str): + return name.lower() == pattern.lower() + elif isinstance(pattern, list): + return name.lower() in [pann_name.lower() for pann_name in pattern] + elif isinstance(pattern, (bool, int)): + return can_download(name, "all" if pattern else "none") + else: + raise TypeError( + f"Invalid cfg.pann argument. Must be a string, a list of strings, a bool or an int, found {pattern.__class__.__name__}." + ) + + urls = { + name: model_info + for name, model_info in PANN_PRETRAINED_URLS.items() + if can_download(name, cfg.pann) + } + + for i, (name, model_info) in enumerate(urls.items()): + fpath = osp.join(ckpt_dir, model_info["fname"]) + + if osp.isfile(fpath): + pylog.info( + f"Model '{name}' already downloaded in '{fpath}'. ({i+1}/{len(urls)})" + ) + else: + pylog.info( + f"Start downloading pre-trained PANN model '{name}' ({i+1}/{len(urls)})..." + ) + download_url_to_file( + model_info["url"], fpath, progress=cfg.verbose >= 1 + ) + pylog.info(f"Model '{name}' downloaded in '{fpath}'.") + + if cfg.cnext: + ckpt_dpath = osp.join(torch.hub.get_dir(), "checkpoints") + urls = CNEXT_PRETRAINED_URLS + + for i, (name, info) in enumerate(urls.items()): + url = info["url"] + fname = info["fname"] + fpath = osp.join(ckpt_dpath, fname) + + if osp.isfile(fpath): + pylog.info( + f"Model '{name}' already downloaded in '{fpath}'. 
({i+1}/{len(urls)})" + ) + else: + pylog.info( + f"Start downloading pre-trained CNext model '{name}' ({i+1}/{len(urls)})..." + ) + download_url_to_file(url, fpath, progress=cfg.verbose >= 1) + + +def download_dataset(cfg: DictConfig) -> dict[str, AACDatasetLike]: + # Download a dataset + hydra_cfg = HydraConfig.get() + dataname = hydra_cfg.runtime.choices["data"] + + dsets: dict[str, Any] = {} + + dataroot: str = cfg.data.root + dataroot = osp.expandvars(dataroot) + dataroot = osp.expanduser(dataroot) + os.makedirs(dataroot, exist_ok=True) + + if dataname == "audiocaps": + AudioCaps.FORCE_PREPARE_DATA = False + + if cfg.data.subsets is None: + subsets = AudioCapsCard.SUBSETS + else: + subsets = cfg.data.subsets + + if cfg.data.audiocaps_caps_fix_fpath is not None: + if "train" not in subsets: + pylog.error( + f"Invalid combinaison of arguments {cfg.data.audiocaps_caps_fix_fpath=} with {subsets=}." + ) + else: + subsets = list(subsets) + subsets.remove("train") + new_subset = osp.basename(cfg.data.audiocaps_caps_fix_fpath)[:-4] + subsets.append(new_subset) + + AudioCaps.SUBSETS = AudioCaps.SUBSETS + (new_subset,) # type: ignore + _AUDIOCAPS_LINKS.update( + { + new_subset: { + "captions": { + "url": None, + "fname": osp.basename( + cfg.data.audiocaps_caps_fix_fpath + ), + }, + }, + } + ) + + for subset in subsets: + dsets[subset] = AudioCaps( + dataroot, + subset, + download=cfg.data.download, + verbose=cfg.verbose, + with_tags=cfg.data.with_tags, + ffmpeg_path=cfg.path.ffmpeg, + ytdl_path=cfg.path.ytdl, + ) + + elif dataname == "clotho": + Clotho.FORCE_PREPARE_DATA = False + Clotho.CLEAN_ARCHIVES = cfg.data.clean_archives + + if cfg.data.subsets is None: + subsets = ClothoCard.SUBSETS + else: + subsets = cfg.data.subsets + + for subset in subsets: + dsets[subset] = Clotho( + dataroot, + subset, + download=cfg.data.download, + verbose=cfg.verbose, + version=cfg.data.version, + ) + + elif dataname == "macs": + MACS.FORCE_PREPARE_DATA = False + MACS.CLEAN_ARCHIVES = cfg.data.clean_archives + + if cfg.data.subsets is None: + subsets = MACSCard.SUBSETS + else: + subsets = cfg.data.subsets + + for subset in subsets: + dsets[subset] = MACS( + dataroot, + subset=subset, + download=cfg.data.download, + verbose=cfg.verbose, + ) + + if cfg.data.tags_to_str: + dsets = { + subset: TransformWrapper(dset, str, "tags") + for subset, dset in dsets.items() + } + + elif dataname == "hdf": + hdf_fpaths = get_hdf_fpaths( + cfg.data.name, cfg.data.subsets, dataroot, cfg.data.hdf_suffix + ) + dsets = {} + for subset, hdf_fpath in hdf_fpaths.items(): + ds = HDFDataset(hdf_fpath) + ds = AACSelectColumnsWrapper(ds, include=cfg.data.include_columns) + dsets[subset] = ds + + elif dataname == "wavcaps": + if cfg.data.subsets is None: + subsets = ("as_bbc_sb",) + else: + subsets = cfg.data.subsets + + dsets = { + subset: WavCaps( + dataroot, + subset, + download=cfg.data.download, + hf_cache_dir=cfg.data.hf_cache_dir, + verbose=cfg.verbose, + ) + for subset in subsets + } + + elif dataname in ("none",): + dsets = {} + + else: + accepted_datasets = ( + "audiocaps", + "clotho", + "hdf", + "macs", + "wavcaps", + "none", + ) + raise RuntimeError( + f"Unknown dataset '{dataname}'. Expected one of {accepted_datasets}." 
+ ) + + dsets = filter_dsets(cfg, dsets) + + if cfg.verbose >= 2 and len(dsets) > 0: + rand_subset = random.choice(list(dsets.keys())) + dset = dsets[rand_subset] + if len(dset) > 0: + rand_idx = random.randint(0, len(dset) - 1) + meta_lst = dset.at(rand_idx, "audio_metadata") + pylog.debug(f"Sample random metadata from subset '{rand_subset}':") + pylog.debug(f"{meta_lst}") + + return dsets + + +def filter_dsets( + cfg: DictConfig, + dsets: dict[str, AACDatasetLike], +) -> dict[str, AACDatasetLike]: + min_audio_size = float(cfg.datafilter.min_audio_size) + max_audio_size = float(cfg.datafilter.max_audio_size) + use_range_filt = cfg.datafilter.imin is not None or cfg.datafilter.imax is not None + use_duration_filt = min_audio_size > 0.0 or not math.isinf(max_audio_size) + use_sr_filt = cfg.datafilter.sr is not None + + if not any((use_range_filt, use_duration_filt, use_sr_filt)): + return dsets + + indexes_dic: dict[str, list[int]] = {} + for subset, ds in dsets.items(): + indexes_dic[subset] = list(range(len(ds))) + + meta_dic: dict[str, list[AudioMetaData]] = {} + + if use_duration_filt or use_sr_filt: + for subset, ds in dsets.items(): + fpaths = ds[:, "fpath"] + if cfg.verbose >= 2: + pylog.debug(f"Loading durations from {len(ds)} audio files...") + meta_lst = disk_cache( + load_audio_metadata, fpaths, cache_path=cfg.path.cache + ) + meta_dic[subset] = list(meta_lst.values()) + + if use_range_filt: + if cfg.verbose >= 1: + pylog.info( + f"Limit datasets in [{cfg.datafilter.imin}, {cfg.datafilter.imax}]." + ) + + imin = cfg.datafilter.imin + imax = cfg.datafilter.imax + indexes_dic = { + subset: indexes[imin:imax] for subset, indexes in indexes_dic.items() + } + + if use_duration_filt: + for subset, indexes in indexes_dic.items(): + meta_lst = meta_dic[subset] + meta_lst = [meta_lst[idx] for idx in indexes] + durations = [(meta.num_frames / meta.sample_rate) for meta in meta_lst] + prev_size = len(indexes) + indexes_and_durations = [ + (idx, duration) + for idx, duration in zip(indexes, durations, strict=True) + if min_audio_size <= duration <= max_audio_size + ] + indexes, durations = unzip(indexes_and_durations) + indexes_dic[subset] = indexes + + n_excluded = prev_size - len(indexes) + if cfg.verbose >= 1: + pylog.info( + f"Exclude {n_excluded}/{prev_size} files with audio size not in [{min_audio_size}, {max_audio_size}] seconds in {subset=}." + ) + pylog.info( + f"Durations are now in range [{min(durations):.2f}, {max(durations):.2f}] s." + ) + + if use_sr_filt: + for subset, indexes in indexes_dic.items(): + meta_lst = meta_dic[subset] + meta_lst = [meta_lst[idx] for idx in indexes] + sample_rates = [meta.sample_rate for meta in meta_lst] + prev_size = len(indexes) + indexes = [ + idx + for idx, sr in zip(indexes, sample_rates, strict=True) + if sr == cfg.datafilter.sr + ] + indexes_dic[subset] = indexes + + n_excluded = prev_size - len(indexes) + if cfg.verbose >= 1: + pylog.info( + f"Exclude {n_excluded}/{prev_size} files with sample_rate != {cfg.datafilter.sr} Hz in {subset=}." 
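`filter_dsets` keeps only the indexes whose audio duration (`num_frames / sample_rate`) falls inside `[datafilter.min_audio_size, datafilter.max_audio_size]` and, if requested, whose sample rate equals `datafilter.sr`. A self-contained sketch of that predicate on fabricated metadata:

```python
# Fabricated (num_frames, sample_rate) pairs: 5 s @ 44.1 kHz, 40 s @ 44.1 kHz,
# and 10 s @ 32 kHz.
metas = [(44100 * 5, 44100), (44100 * 40, 44100), (32000 * 10, 32000)]
min_audio_size, max_audio_size, target_sr = 1.0, 30.0, 44100

kept = [
    idx
    for idx, (num_frames, sr) in enumerate(metas)
    if min_audio_size <= num_frames / sr <= max_audio_size and sr == target_sr
]
assert kept == [0]  # the 40 s clip and the 32 kHz clip are filtered out
```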
+ ) + + dsets = {subset: AACSubset(ds, indexes_dic[subset]) for subset, ds in dsets.items()} + return dsets + + +def pack_dsets_to_hdf(cfg: DictConfig, dsets: dict[str, Any]) -> None: + if not cfg.pack_to_hdf: + return None + + hydra_cfg = HydraConfig.get() + dataname = hydra_cfg.runtime.choices["data"] + audio_transform_name = hydra_cfg.runtime.choices["audio_t"] + sentence_transform_name = hydra_cfg.runtime.choices["text_t"] + + if len(dsets) == 0: + pylog.warning( + f"Invalid value {dataname=} with pack_to_hdf=true. (found {len(dsets)} datasets)" + ) + return None + + if hasattr(cfg.audio_t, "src_sr"): + src_sr = cfg.audio_t.src_sr + for name, dset in dsets.items(): + if ( + isinstance(dset, HDFDataset) + or not isinstance(dset, AACDatasetLike) + or "fpath" not in dset.column_names + or len(dset) == 0 + ): + continue + fpath = dset[0, "fpath"] + meta = torchaudio.info(fpath) # type: ignore + if src_sr != meta.sample_rate: + raise ValueError( + f"Invalid input sr {src_sr} with audio sr {meta.sample_rate}. (with dataset '{name}')" + ) + + dataroot: str = cfg.path.data + dataroot = osp.expandvars(dataroot) + dataroot = osp.expanduser(dataroot) + hdf_root = osp.join(dataroot, "HDF") + os.makedirs(hdf_root, exist_ok=True) + + for subset, dset in dsets.items(): + audio_transform_params = dict(cfg.audio_t) + sentence_transform_params = dict(cfg.text_t) + + audio_tfm = hydra.utils.instantiate(audio_transform_params) + text_tfm = hydra.utils.instantiate(sentence_transform_params) + + if isinstance(audio_tfm, nn.Module) and cfg.verbose >= 1: + n_params = count_params(audio_tfm, only_trainable=False) + pylog.info(f"Nb params in audio transform: {n_params}") + + if isinstance(text_tfm, nn.Module) and cfg.verbose >= 1: + n_params = count_params(text_tfm, only_trainable=False) + pylog.info(f"Nb params in text transform: {n_params}") + + pre_save_transforms = { + "audio": audio_tfm, + "captions": text_tfm, + } + transforms_params = { + "audio": audio_transform_params, + "captions": sentence_transform_params, + } + if cfg.csum_in_hdf_name: + csum = csum_any(transforms_params) % 1000 + csum_suffix = f"_{csum}" + else: + csum_suffix = "" + + hdf_fname = f"{dataname}_{subset}_{audio_transform_name}_{sentence_transform_name}{csum_suffix}.hdf" + + if cfg.datafilter.imin is not None or cfg.datafilter.imax is not None: + hdf_fname = hdf_fname.replace( + ".hdf", f"_lim_{cfg.datafilter.imin}_{cfg.datafilter.imax}.hdf" + ) + if cfg.post_hdf_name is not None: + hdf_fname = hdf_fname.replace(".hdf", f"_{cfg.post_hdf_name}.hdf") + hdf_fpath = osp.join(hdf_root, hdf_fname) + + if not osp.isfile(hdf_fpath) or cfg.overwrite_hdf: + if cfg.verbose >= 1: + pylog.info( + f"Start packing the {dataname}_{subset} dataset to HDF file {hdf_fname}..." + ) + + metadata = { + "transform_params": transforms_params, + } + if hasattr(cfg.audio_t, "tgt_sr"): + metadata["sr"] = cfg.audio_t.tgt_sr + + if cfg.verbose >= 1: + pylog.debug(yaml.dump({"Metadata": metadata})) + + pre_save_transform = DictTransform(pre_save_transforms) + + hdf_dset = pack_to_hdf( + dset, + hdf_fpath, + pre_save_transform, # type: ignore + overwrite=cfg.overwrite_hdf, + metadata=str(metadata), + verbose=cfg.verbose, + loader_bsize=cfg.data.bsize, + loader_n_workers=cfg.data.n_workers, + ) + hdf_dset.open() + else: + if cfg.verbose >= 1: + pylog.info( + f"Dataset {dataname}_{subset} is already packed to hdf in {hdf_fpath=}." 
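The name of each packed HDF file is assembled from the dataset name, the subset, the Hydra audio/text transform choices, an optional checksum suffix, and an optional `post_hdf_name` suffix. An illustrative reconstruction (the transform names and suffixes below are hypothetical, not the project's actual config choices):

```python
dataname, subset = "clotho", "val"
audio_transform_name, sentence_transform_name = "my_audio_t", "my_text_t"  # hypothetical
csum_suffix, post_hdf_name = "_123", "bl"                                  # hypothetical

hdf_fname = (
    f"{dataname}_{subset}_{audio_transform_name}_{sentence_transform_name}"
    f"{csum_suffix}.hdf"
)
hdf_fname = hdf_fname.replace(".hdf", f"_{post_hdf_name}.hdf")
print(hdf_fname)  # clotho_val_my_audio_t_my_text_t_123_bl.hdf
```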
+ ) + + hdf_dset = HDFDataset(hdf_fpath) + + if cfg.debug: + # Sanity check + idx = int(torch.randint(len(dset), ()).item()) + + dset_item: dict[str, Any] = dict(dset[idx]) + for name, transform in pre_save_transforms.items(): + if name in dset_item.keys() and transform is not None: + dset_item[name] = transform(dset_item[name]) + hdf_item = hdf_dset[idx] + + dset_keys_in_hdf_keys = all( + key in hdf_item.keys() for key in dset_item.keys() + ) + same_dset_len = len(dset) == len(hdf_dset) + + pylog.debug(f"Check with item N°{idx=}") + pylog.debug( + f"Check {dset_keys_in_hdf_keys=} ({dset_item.keys()} in {hdf_item})" + ) + pylog.debug(f"Check {same_dset_len=} ({len(dset)} == {len(hdf_dset)})") + + all_same = True + + if "audio" in dset_item.keys(): + rtol = 10**-3 + dset_audio, hdf_audio = dset_item["audio"], hdf_item["audio"] + same_audio_shape = dset_audio.shape == hdf_audio.shape + close_audio = same_audio_shape and torch.allclose( + dset_audio, hdf_audio, rtol=rtol + ) + same_audio = same_audio_shape and dset_audio.eq(hdf_audio).all().item() + all_same = all_same and close_audio and same_audio + + pylog.debug( + f"Check {same_audio_shape=} ({dset_audio.shape} == {hdf_audio.shape})" + ) + pylog.debug(f"Check {close_audio=} ({rtol=})") + pylog.debug(f"Check {same_audio=}") + + if "captions" in dset_item.keys(): + dset_captions, hdf_captions = ( + dset_item["captions"], + hdf_item["captions"], + ) + same_captions = len(dset_captions) == len(hdf_captions) and all( + c1 == c2 for c1, c2 in zip(dset_captions, hdf_captions) + ) + captions_eq = ( + f"(\n{dset_captions}\n == \n{hdf_captions}\n)" + if not same_captions + else "" + ) + all_same = all_same and same_captions + + pylog.debug(f"Check {same_captions=} {captions_eq}") + + if not all_same: + pylog.warning( + f"Check has failed after packing {dataname} to HDF. (dataset={dset.__class__.__name__}, {subset=})\n" + f"NOTE: if a transform is stochastic, you can ignore this warning." 
+ ) + + +@hydra.main( + version_base=None, + config_path=osp.join("..", "conf"), + config_name="prepare", +) +def main_prepare(cfg: DictConfig) -> None: + """Download models and datasets.""" + run_start = time.perf_counter() + setup_run(cfg) + + # Add JAVA to PATH for language tool usage on Osirim + java_dir = osp.dirname(cfg.path.java) + if java_dir not in ("", "."): + os.environ["PATH"] += f"{os.pathsep}{java_dir}" + + download_models(cfg) + dsets = download_dataset(cfg) + pack_dsets_to_hdf(cfg, dsets) + + # Download AAC metrics + download_aac_metrics( + cache_path=cfg.path.cache, + tmp_path=cfg.path.tmp, + ptb_tokenizer=cfg.ptb_tokenizer, + meteor=cfg.meteor, + spice=cfg.spice, + fense=cfg.fense, + verbose=cfg.verbose, + ) + + subrun_path = get_subrun_path() + save_to_dir( + subrun_path=subrun_path, + tokenizers=None, # type: ignore + git_hash=cfg.git_hash, + cfg=cfg, + verbose=cfg.verbose, + ) + + run_end = time.perf_counter() + teardown_run(cfg, run_start, run_end) + + +if __name__ == "__main__": + main_prepare() diff --git a/src/conette/retrieve.py b/src/conette/retrieve.py new file mode 100755 index 000000000..9f2511e03 --- /dev/null +++ b/src/conette/retrieve.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os + +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +os.environ["TRANSFORMERS_OFFLINE"] = "TRUE" +os.environ["HF_HUB_OFFLINE"] = "TRUE" + +import glob +import logging +import os.path as osp +import pickle +import time + +from typing import Any + +import hydra +import torch +import tqdm +import yaml + +from omegaconf import DictConfig +from torch import Tensor +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader + +from aac_metrics.utils.collections import flat_list +from aac_metrics.utils.tokenization import preprocess_mult_sents + +from conette.callbacks.stats_saver import save_to_dir +from conette.datamodules.collate import AdvancedCollateDict +from conette.datasets.hdf import HDFDataset +from conette.datasets.utils import LambdaDataset +from conette.metrics.retrieval import retrieval_metrics +from conette.nn.functional.get import get_device +from conette.nn.functional.mask import masked_mean +from conette.nn.functional.misc import move_to_rec +from conette.utils.csv_utils import save_csv_list +from conette.utils.dcase import export_to_dcase_task6b_csv +from conette.utils.hydra import setup_resolvers, get_subrun_path +from conette.utils.yaml_utils import save_yaml +from conette.predict import load_pl_module +from conette.train import setup_run, teardown_run + + +# Note: this function must be called globally +setup_resolvers() + +pylog = logging.getLogger(__name__) + + +def sizes_to_matrix(sizes: list[int]) -> Tensor: + n_audios = len(sizes) + n_caps = sum(sizes) + matrix = torch.full((n_audios, n_caps), False, dtype=torch.bool) + count = 0 + for idx, size in enumerate(sizes): + matrix[idx, count : count + size] = True + count += size + return matrix + + +def scale_losses(losses: Tensor) -> Tensor: + """(queries, targets), A2T is ok but T2A requires transpose before and after.""" + EPSILON = 1e-5 + mins = losses.min(dim=0).values + maxs = losses.max(dim=0).values + scaled_losses = (losses - mins) / (maxs - mins) + + n_queries = scaled_losses.shape[0] + + for query_idx in range(n_queries): + scores = scaled_losses[query_idx] + zero_mask = scores == scores.min() + if zero_mask.sum() == 0: + 
continue + orig_losses = losses[query_idx, zero_mask] + ranks = orig_losses.argsort(descending=False).to(dtype=scaled_losses.dtype) + # if len(ranks) = 3, rank 0 -> -3, rank 1 -> -2 and rank 2 -> -1. (lower value is better for losses) + scaled_losses[query_idx, zero_mask] = ( + scores.min() + (ranks - ranks.shape[0]) * EPSILON + ) + + return scaled_losses + + +@hydra.main( + version_base=None, + config_path=osp.join("..", "conf"), + config_name="retrieve", +) +def run_retrieve(cfg: DictConfig) -> None: + run_start = time.perf_counter() + setup_run(cfg) + + resumes = cfg.resume + if isinstance(resumes, str): + resumes = [resumes] + + if isinstance(resumes, list): + logdirs = [match for logdir in resumes for match in glob.glob(logdir)] + else: + raise TypeError(f"Invalid resume type {type(resumes)}.") + + hdf_root = osp.join(cfg.path.data, "HDF") + hdf_fnames = cfg.hdf_fnames + if isinstance(hdf_fnames, str): + hdf_fnames = [hdf_fnames] + + logdirs = list(sorted(logdirs)) + if cfg.verbose >= 1: + result_dnames = [ + osp.join(osp.basename(osp.dirname(logdir)), osp.basename(logdir)) + for logdir in logdirs + ] + print(f"Found {len(logdirs)} logdirs:") + print(yaml.dump(result_dnames, sort_keys=False)) + + if len(logdirs) <= 0: + pylog.warning(f"No pre-trained model has been found in {cfg.resume}.") + run_end = time.perf_counter() + teardown_run(cfg, run_start, run_end) + return None + + device = get_device(cfg.device) + plms: dict[str, Any] = { # type: ignore + logdir: load_pl_module(logdir, device=device) for logdir in tqdm.tqdm(logdirs) + } + plm0 = next(iter(plms.values())) + tokenizer0 = plm0.tokenizer # assume all tokenizers are the same + assert all(tokenizer0 == plm.tokenizer for plm in plms.values()) + + dsets = {fname: HDFDataset(osp.join(hdf_root, fname)) for fname in hdf_fnames} + mrefs_dic = {ds_name: dset[:, "captions"] for ds_name, dset in dsets.items()} + mrefs_dic = { + ds_name: preprocess_mult_sents( + mrefs, cfg.path.cache, cfg.path.java, cfg.path.tmp, verbose=cfg.verbose + ) + for ds_name, mrefs in mrefs_dic.items() + } + + flat_mrefs_and_sizes_dic = { + ds_name: flat_list(mrefs) for ds_name, mrefs in mrefs_dic.items() + } + flat_mrefs_dic = { + ds_name: refs for ds_name, (refs, _sizes) in flat_mrefs_and_sizes_dic.items() + } + is_matching_matrices = { + ds_name: sizes_to_matrix(sizes) + for ds_name, (_refs, sizes) in flat_mrefs_and_sizes_dic.items() + } + del flat_mrefs_and_sizes_dic + + captions_dic: dict[str, Tensor] = {ds_name: tokenizer0.encode_batch(refs, padding="batch") for ds_name, refs in flat_mrefs_dic.items()} # type: ignore + captions_dic = { + ds_name: queries.unsqueeze(dim=1).repeat(1, cfg.bsize, 1).to(device=device) + for ds_name, queries in captions_dic.items() + } + + caps_in_dic = { + ds_name: queries[:, :, :-1].contiguous() + for ds_name, queries in captions_dic.items() + } + caps_out_dic = { + ds_name: queries[:, :, 1:].contiguous() + for ds_name, queries in captions_dic.items() + } + n_caps_dic = { + ds_name: len(queries_in) for ds_name, queries_in in caps_in_dic.items() + } + + collator = AdvancedCollateDict(pad_values=dict(audio=0.0)) + + def build_dset(dset) -> LambdaDataset: + def get_item(idx: int) -> dict[str, Any]: + return { + "audio": dset[idx, "audio"], + "audio_shape": dset[idx, "audio_shape"], + "fname": dset[idx, "fname"], + "index": torch.as_tensor(idx), + "dataset": dset[idx, "dataset"], + } + + lbd_dset = LambdaDataset(get_item, len(dset)) + return lbd_dset + + n_audios_dic = {ds_name: len(dset) for ds_name, dset in dsets.items()} + 
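`sizes_to_matrix` turns the per-audio caption counts into the boolean relevance matrix consumed later by `retrieval_metrics`, while `scale_losses` min-max normalizes each column and then breaks ties at the row minimum with tiny epsilon offsets ordered by the raw losses. A small check of the matrix layout (assuming the helper defined at the top of this file is in scope):

```python
import torch

# Two audios with 2 and 3 reference captions -> a (2, 5) relevance mask over
# the flattened caption list.
expected = torch.tensor(
    [
        [True, True, False, False, False],
        [False, False, True, True, True],
    ]
)
assert torch.equal(sizes_to_matrix([2, 3]), expected)
```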
lbd_dsets = {ds_name: build_dset(dset) for ds_name, dset in dsets.items()} + loaders = { + ds_name: DataLoader( + lbd_dset, + batch_size=cfg.bsize, + shuffle=False, + collate_fn=collator, + num_workers=cfg.n_workers, + ) + for ds_name, lbd_dset in lbd_dsets.items() + } + + # Compute losses + all_losses: dict[str, Tensor] = {} + DEFAULT_SCORE = -999.0 + + for ds_idx, (ds_fname, loader) in enumerate(loaders.items()): + caps_in = caps_in_dic[ds_fname] + caps_out = caps_out_dic[ds_fname] + n_caps = n_caps_dic[ds_fname] + n_audios = n_audios_dic[ds_fname] + + ds_losses = torch.full( + (n_audios, len(plms), n_caps), + DEFAULT_SCORE, + dtype=torch.float, + device=device, + ) + + for batch_idx, batch in enumerate(tqdm.tqdm(loader, disable=True)): + batch = move_to_rec(batch, device=device) + batch = plm0.on_after_batch_transfer(batch, ds_idx) + + audio = batch["audio"] + audio_shape = batch["audio_shape"] + indexes = batch["index"] + + cur_bsize = len(audio) + batch_losses = torch.full( + (cur_bsize, len(plms), n_caps), + DEFAULT_SCORE, + dtype=torch.float, + device=device, + ) + + for plm_idx, (plm_name, plm) in enumerate(plms.items()): + if cfg.verbose >= 1: + pylog.info( + f"{ds_idx=}/{len(loaders)}, {batch_idx=}/{len(loader)}, {plm_idx=}/{len(plms)}" + ) + + batch_plm_losses = torch.full( + (cur_bsize, n_caps), + DEFAULT_SCORE, + dtype=torch.float, + device=device, + ) + + enc_outs = plm.encode_audio(audio, audio_shape) + + pbar = tqdm.tqdm(caps_in) + for cap_idx, (cap_in, cap_out) in enumerate(zip(pbar, caps_out)): + if cur_bsize < cap_in.shape[0]: + cap_in = cap_in[:cur_bsize] + cap_out = cap_out[:cur_bsize] + + logits = plm.decode_audio(enc_outs, "forcing", caps_in=cap_in) + losses = F.cross_entropy( + logits, + cap_out, + ignore_index=plm.pad_id, + reduction="none", + weight=None, + ) + losses = masked_mean(losses, cap_out != plm.pad_id, dim=1) + # losses: (bsize,) + batch_plm_losses[:, cap_idx] = losses + if cfg.debug: + break + + batch_losses[:, plm_idx] = batch_plm_losses + if cfg.debug: + break + + ds_losses[indexes] = batch_losses + if cfg.debug: + break + + all_losses[ds_fname] = ds_losses + + all_losses = {ds_name: ds_losses.cpu() for ds_name, ds_losses in all_losses.items()} + + # Save losses matrix + subrun_path = get_subrun_path() + os.makedirs(osp.join(subrun_path, "dcase"), exist_ok=True) + os.makedirs(osp.join(subrun_path, "metrics"), exist_ok=True) + + all_retrieval_scores: dict[str, list[dict[str, Any]]] = {"t2a": [], "a2t": []} + + for ds_fname, ds_losses in all_losses.items(): + ds = dsets[ds_fname] + ds_name = ds[0, "dataset"] + ds_subset = ds[0, "subset"] + captions = flat_mrefs_dic[ds_fname] + + # Dump losses + # ds_losses: (n_audios, n_plms, n_queries) + for plm_idx, plm_name in enumerate(plms.keys()): + plm_losses = ds_losses[:, plm_idx].contiguous() + + fname = f"{ds_fname.replace('.hdf', '')}-plm_idx_{plm_idx}-losses.pickle" + fpath = osp.join(subrun_path, fname) + data = { + "losses": plm_losses, + "captions": captions, + "audio_fnames": dsets[ds_fname][:, "fname"], + "ds_name": ds_fname, + "plm_idx": plm_idx, + "plm_name": plm_name, + } + with open(fpath, "wb") as file: + pickle.dump(data, file) + del plm_losses + + # Compute retrieval results + tasks_and_modes = [("t2a", cfg.t2a_modes), ("a2t", cfg.a2t_modes)] + tasks_and_modes = [ + (task, [modes] if isinstance(modes, str) else list(modes)) + for task, modes in tasks_and_modes + ] + for task, modes in tasks_and_modes: + for mode in modes: + for plm_idx, plm_name in enumerate(plms.keys()): + plm_losses = 
ds_losses[:, plm_idx].contiguous() + if task == "t2a": + plm_losses = ds_losses[:, plm_idx].transpose(0, 1) + is_matching = is_matching_matrices[ds_fname].transpose(0, 1) + elif task == "a2t": + plm_losses = ds_losses[:, plm_idx] + is_matching = is_matching_matrices[ds_fname] + else: + raise RuntimeError(f"Invalid value {task=}.") + # plm_losses: (targets, queries) + + if mode == "loss": + plm_scores = -plm_losses + + elif mode == "scaled_loss": + scaled_losses = scale_losses(plm_losses) + plm_scores = -scaled_losses + + n_max_per_audio = ( + plm_scores == plm_scores.max(dim=1).values[:, None] + ).sum(dim=1) + assert ( + n_max_per_audio.eq(1).all().item() + ) # check if there is always an unique top1 + + else: + raise ValueError(f"Invalid argument {mode=}.") + + if is_matching is not None: + retrieval_outs_corpus, _retrieval_outs_sents = retrieval_metrics(plm_scores, is_matching) # type: ignore + retrieval_outs_corpus: dict[str, Tensor] + + retrieval_outs_corpus_lst = { + k: v.tolist() for k, v in retrieval_outs_corpus.items() + } + + yaml_fname = f"{task}-{ds_name}_{ds_subset}-mode_{mode}-plm_idx_{plm_idx}.yaml" + yaml_fpath = osp.join(subrun_path, "metrics", yaml_fname) + save_yaml(retrieval_outs_corpus_lst, yaml_fpath) + + if cfg.verbose >= 1: + pylog.info( + f"Audio-Text retrieval scores for {task.upper()} with {mode=}:\n{yaml.dump(retrieval_outs_corpus_lst, sort_keys=False)}" + ) + + all_retrieval_scores[task].append( + { + "ds_name": ds_name, + "ds_subset": ds_subset, + "ds_fname": ds_fname, + "mode": mode, + "plm_idx": plm_idx, + "plm_name": plm_name, + } + | retrieval_outs_corpus_lst + ) + + if task == "t2a": + csv_fname = f"{task}-{ds_name}_{ds_subset}-mode_{mode}-plm_idx_{plm_idx}.csv" + csv_fpath = osp.join(subrun_path, "dcase", csv_fname) + top_audio_indexes = plm_scores.argsort(dim=1, descending=True)[ + :, :10 + ].tolist() + predicted_fnames = [ + ds[indexes, "fname"] for indexes in top_audio_indexes + ] + export_to_dcase_task6b_csv( + csv_fpath, captions, predicted_fnames + ) + + for task, scores in all_retrieval_scores.items(): + csv_fname = f"metrics-{task}.csv" + csv_fpath = osp.join(subrun_path, csv_fname) + save_csv_list(scores, csv_fpath) + + flatten_scores = [ + {"task": task} | dic + for task, scores in all_retrieval_scores.items() + for dic in scores + ] + csv_fpath = osp.join(subrun_path, "metrics.csv") + save_csv_list(flatten_scores, csv_fpath) + + run_end = time.perf_counter() + save_to_dir( + subrun_path, + git_hash=cfg.git_hash, + cfg=cfg, + verbose=cfg.verbose, + ) + teardown_run(cfg, run_start, run_end) + + +if __name__ == "__main__": + run_retrieve() diff --git a/src/conette/tokenization/tokenizers/ptb.py b/src/conette/tokenization/tokenizers/ptb.py new file mode 100644 index 000000000..36a42d658 --- /dev/null +++ b/src/conette/tokenization/tokenizers/ptb.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from pathlib import Path +from typing import Iterable, Union + +from aac_metrics.utils.tokenization import ptb_tokenize_batch + +from conette.tokenization.constants import SPECIAL_TOKENS +from conette.tokenization.tokenizers.base import StrTokenizer +from conette.tokenization.tokenizers.common import build_mappings_and_vocab + + +class PTBWordTokenizer(StrTokenizer): + def __init__( + self, + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, + special_tokens: Iterable[str] = SPECIAL_TOKENS, + ) -> None: + super().__init__() + self._cache_path = cache_path + 
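+        # cache_path, java_path and tmp_path are forwarded as-is to
+        # ptb_tokenize_batch; java_path should point to the Java executable
+        # required by the Java-based PTB tokenizer shipped with aac-metrics.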
self._java_path = java_path + self._tmp_path = tmp_path + self._special_tokens = special_tokens + + def detokenize_batch(self, sentences: Iterable[Iterable[str]]) -> list[str]: + decoded_sentences = [" ".join(sentence) for sentence in sentences] + return decoded_sentences + + def fit(self, sentences: Iterable[str]) -> tuple[list, dict, dict, dict]: + encoded_sentences = self.tokenize_batch(sentences) + itos, stoi, vocab = build_mappings_and_vocab( + encoded_sentences, self._special_tokens + ) + return encoded_sentences, itos, stoi, vocab + + def get_backend(self) -> str: + return "ptb" + + def get_level(self) -> str: + return "word" + + def tokenize_batch(self, sentences: Iterable[str], **kwargs) -> list[list[str]]: + return ptb_tokenize_batch( + sentences, + cache_path=self._cache_path, + java_path=self._java_path, + tmp_path=self._tmp_path, + ) diff --git a/src/conette/tokenization/tokenizers/word.py b/src/conette/tokenization/tokenizers/word.py index f2871a2e3..cf304e668 100644 --- a/src/conette/tokenization/tokenizers/word.py +++ b/src/conette/tokenization/tokenizers/word.py @@ -4,6 +4,11 @@ from typing import ClassVar from conette.tokenization.tokenizers.base import StrTokenizer + +try: + from conette.tokenization.tokenizers.ptb import PTBWordTokenizer +except ImportError: + PTBWordTokenizer = None from conette.tokenization.tokenizers.spacy import SpacyWordTokenizer from conette.tokenization.tokenizers.wrapper import TokenizerWrapper @@ -11,11 +16,12 @@ class WordTokenizer(TokenizerWrapper): """Tokenizer facade class for the following word tokenizers classes: - - :class:`~aac.tokenization.tokenizers.spacy.SpacyWordTokenizer` + - :class:`~conette.tokenization.tokenizers.ptb.PTBWordTokenizer` + - :class:`~conette.tokenization.tokenizers.spacy.SpacyWordTokenizer` """ - BACKENDS: ClassVar[tuple[str, ...]] = ("spacy",) + BACKENDS: ClassVar[tuple[str, ...]] = ("spacy", "ptb") def __init__(self, backend: str = "spacy", *args, **kwargs) -> None: tokenizer = _word_tokenizer_factory(backend, *args, **kwargs) @@ -25,6 +31,12 @@ def __init__(self, backend: str = "spacy", *args, **kwargs) -> None: def _word_tokenizer_factory(backend: str = "spacy", *args, **kwargs) -> StrTokenizer: if backend == "spacy": tokenizer = SpacyWordTokenizer(*args, **kwargs) + elif backend == "ptb": + if PTBWordTokenizer is None: + raise RuntimeError( + "Please install aac-metrics package to use ptb tokenizer backend. (found None PTBWordTokenizer)" + ) + tokenizer = PTBWordTokenizer(*args, **kwargs) else: raise ValueError( f"Invalid argument {backend=}. 
(expected one of {WordTokenizer.BACKENDS})" diff --git a/src/conette/train.py b/src/conette/train.py new file mode 100644 index 000000000..a4dc480d2 --- /dev/null +++ b/src/conette/train.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os + +os.environ["MKL_NUM_THREADS"] = "2" +os.environ["NUMEXPR_NUM_THREADS"] = "2" +os.environ["OMP_NUM_THREADS"] = "2" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +os.environ["TRANSFORMERS_OFFLINE"] = "TRUE" +os.environ["HF_HUB_OFFLINE"] = "TRUE" + +import logging +import math +import os.path as osp +import sys +import time + +from typing import Callable, Optional, Union + +import colorlog +import hydra +import torch +import yaml + +from hydra.utils import instantiate +from lightning_fabric.plugins.environments import LightningEnvironment +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.callbacks import ( + Callback, + DeviceStatsMonitor, + EarlyStopping, + ModelCheckpoint, + ModelSummary, +) +from transformers import logging as tfmers_logging + +import conette + +from conette.callbacks.aac_evaluator import AACEvaluator +from conette.callbacks.aac_validator import AACValidator +from conette.callbacks.debug import PrintDebug +from conette.callbacks.deepspeed import DeepSpeedCallback +from conette.callbacks.log import LogGCCallback, LogLRCallback, LogGradNorm, LogRngState +from conette.callbacks.resume import ResumeCallback +from conette.callbacks.stats_saver import StatsSaver +from conette.tokenization.aac_tokenizer import AACTokenizer +from conette.utils.custom_logger import CustomTensorboardLogger +from conette.utils.hydra import setup_resolvers, get_subrun_path, CustomFileHandler +from conette.utils.log_utils import set_loglevel +from conette.utils.misc import copy_slurm_logs, reset_seed + + +# Note: this function must be called globally +setup_resolvers() + +pylog = logging.getLogger(__name__) + + +# Public functions +def setup_run(cfg: DictConfig) -> None: + reset_seed(cfg.seed) + OmegaConf.resolve(cfg) + OmegaConf.set_readonly(cfg, True) + + # Print config + subrun_path = get_subrun_path() + if cfg.verbose >= 1: + pylog.info(f"Subrun: {subrun_path}") + pylog.info(f"Configuration:\n{OmegaConf.to_yaml(cfg, resolve=True)}") + + # Overwrite root logger formatter + rank = os.getenv("SLURM_PROCID", 0) + formatter = colorlog.ColoredFormatter( + f"[%(purple)sRANK{rank}%(reset)s][%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s", + log_colors={ + "DEBUG": "purple", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "bold_red", + }, + ) + root_logger = logging.getLogger() + handlers = root_logger.handlers + for handler in handlers: + handler.setFormatter(formatter) + + # Set log level in some packagers + if cfg.debug or cfg.verbose >= 2: + pkg_level = logging.DEBUG + other_level = logging.WARNING + elif cfg.verbose == 1: + pkg_level = logging.INFO + other_level = logging.ERROR + tfmers_logging.set_verbosity_error() + else: + pkg_level = logging.WARNING + other_level = logging.ERROR + tfmers_logging.set_verbosity_error() + + set_loglevel(("sentence_transformers",), other_level) + set_loglevel((conette,), pkg_level) + pylog.setLevel(pkg_level) + + # Redirect PyTorch lightning outputs to a file + os.makedirs(subrun_path, exist_ok=True) + + fpath_pl_outputs = osp.join(subrun_path, "logs", "lightning_outputs.log") + handler = 
CustomFileHandler(fpath_pl_outputs) + pl_pylog = logging.getLogger("pytorch_lightning") + pl_pylog.addHandler(handler) + + # Set data dir for torch + if cfg.path.torch_hub is not None: + torch_hub_path = osp.expandvars(cfg.path.torch_hub) + torch.hub.set_dir(torch_hub_path) + + # Set PyTorch sharing strategy if needed + if cfg.sharing_strategy is not None: + sharing_strategies = torch.multiprocessing.get_all_sharing_strategies() + if cfg.sharing_strategy not in sharing_strategies: + raise ValueError( + f"Invalid argument {cfg.sharing_strategy=}. (expected one of {tuple(sharing_strategies)})" + ) + if cfg.verbose >= 1: + pylog.info(f"Set sharing strategy to {cfg.sharing_strategy}.") + torch.multiprocessing.set_sharing_strategy(cfg.sharing_strategy) + + # Print debug info + if cfg.verbose >= 2: + overrides_fpath = osp.join(subrun_path, "hydra", "overrides.yaml") + if osp.isfile(overrides_fpath): + with open(overrides_fpath, "r") as file: + overrides = yaml.safe_load(file) + pylog.info(f"Overrides:\n{yaml.dump(overrides, sort_keys=False)}") + + varnames = ( + "SLURM_JOB_NAME", + "SLURM_JOB_ID", + "SLURM_NTASKS", + "SLURM_PROCID", + "SLURM_LOCALID", + "SLURM_NODEID", + ) + values = {} + for name in varnames: + values[name] = os.environ.get(name, None) + pylog.debug(f"Env variables:\n{yaml.dump(values, sort_keys=True)}") + + +def teardown_run(cfg: DictConfig, run_start: float, run_end: float) -> None: + total_duration_s = run_end - run_start + total_duration_h = math.floor(total_duration_s / 3600.0) + + subrun_path = get_subrun_path() + if cfg.verbose >= 1: + total_duration_m = (total_duration_s / 60) % 60 + pylog.info( + f"Results are saved in '{subrun_path}' in {total_duration_h:.0f}h{total_duration_m:02.0f}m." + ) + + fpaths = [ + cfg.get("slurm", {}).get("output", None), + cfg.get("slurm", {}).get("error", None), + ] + copy_slurm_logs(fpaths, subrun_path) + + +def load_callbacks( + cfg: DictConfig, + tokenizers: dict[str, Optional[AACTokenizer]], + datamodule: LightningDataModule, + pl_module: LightningModule, +) -> dict[str, Callback]: + callbacks = {} + + resume_callback = ResumeCallback( + resume=cfg.resume, + strict=cfg.strict_resume, + ign_weights=cfg.ign_weights, + verbose=cfg.verbose, + ) + + if cfg.resume_before_setup: + resume_callback.load_checkpoint(pl_module) + + callbacks["resume"] = resume_callback + + # Add callback to stop training if monitor is NaN + early_stop_callback = EarlyStopping( + check_finite=True, + mode=cfg.ckpts[0].mode, + monitor=cfg.ckpts[0].monitor, + patience=sys.maxsize, + ) + callbacks["early_stop"] = early_stop_callback + + print_debug = PrintDebug(cfg.verbose) + callbacks["print_debug"] = print_debug + + # Add Evaluator for compute test metrics scores at the end of the training (when trainer.test is called) + evaluator = instantiate( + cfg.evaluator, + test_tokenizer=tokenizers["test_tokenizer"], + verbose=cfg.verbose, + ) + callbacks["evaluator"] = evaluator + + if hasattr(cfg.dm, "bsize"): + bsize = cfg.dm.bsize + else: + if cfg.verbose >= 0: + pylog.warning("Cannot detect batch size from data conf.") + bsize = None + + log_lr = LogLRCallback(bsize=bsize) + callbacks["log_lr"] = log_lr + + log_grad_norm = LogGradNorm(bsize=bsize) + callbacks["log_grad_norm"] = log_grad_norm + + if cfg.debug: + log_rng_state = LogRngState(bsize=bsize) + callbacks["log_rng_state"] = log_rng_state + + log_gc = LogGCCallback(bsize=bsize) + callbacks["log_gc"] = log_gc + + subrun_path = get_subrun_path() + stats_saver = StatsSaver( + subrun_path=subrun_path, + 
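+        # on_end="none" presumably disables automatic saving at the end of fit;
+        # main_train() calls stats_saver.save_metrics_stats(...) explicitly once
+        # testing is finished.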
on_end="none", + tokenizers=tokenizers, + git_hash=cfg.git_hash, + cfg=cfg, + verbose=cfg.verbose, + ) + callbacks["stats_saver"] = stats_saver + + if "swa" in cfg.testing.run: + if datamodule is not None: + datamodule.setup("fit") + pl_module.setup("fit") + + swa_callback = instantiate(cfg.testing.swa) + callbacks["swa"] = swa_callback + + if "ema" in cfg.testing.run: + ema_callback = instantiate(cfg.testing.ema) + callbacks["ema"] = ema_callback + + if cfg.debug: + device_stats_monitor = DeviceStatsMonitor() + callbacks["device_stats_monitor"] = device_stats_monitor + + if cfg.debug or cfg.verbose >= 1: + max_depth = 20 + elif cfg.verbose == 1: + max_depth = 1 + else: + max_depth = 0 + + model_summary = ModelSummary(max_depth=max_depth) + callbacks["model_summary"] = model_summary + + if cfg.enable_dspeed: + deepspeed = DeepSpeedCallback(verbose=cfg.verbose) + callbacks["deepspeed"] = deepspeed + + monitors = [ckpt_cfg.monitor for ckpt_cfg in cfg.ckpts] + validator = AACValidator(monitors, cfg.val_metrics_keys) + callbacks["validator"] = validator + + if cfg.trainer.enable_checkpointing: + ckpts = instantiate(cfg.ckpts) + for i, ckpt in enumerate(ckpts): + callbacks[f"ckpt.{i}"] = ckpt + + callbacks = { + name: callback for name, callback in callbacks.items() if callback is not None + } + return callbacks + + +def test_after_fit( + cfg: DictConfig, + datamodule: LightningDataModule, + pl_module: LightningModule, + evaluator: AACEvaluator, + callbacks: dict[str, Callback], + trainer: Trainer, +) -> None: + testing_run = cfg.testing.run + if isinstance(testing_run, str): + testing_run = [testing_run] + else: + testing_run = list(testing_run) + + if "last" in testing_run: + if cfg.verbose >= 1: + pylog.info("Test using last model...") + + evaluator.set_model_name("last") + trainer.test(pl_module, datamodule=datamodule, verbose=cfg.verbose >= 3) + trainer.predict(pl_module, datamodule=datamodule) + + if "swa" in testing_run: + if cfg.verbose >= 1: + pylog.info("Using SWA weights for testing...") + + evaluator.set_model_name("swa") + trainer.test(pl_module, datamodule=datamodule, verbose=cfg.verbose >= 3) + trainer.predict(pl_module, datamodule=datamodule) + + if "best" in testing_run: + ckpts = trainer.checkpoint_callbacks + n_tests_done = 0 + + for ckpt in ckpts: + if not isinstance(ckpt, ModelCheckpoint) or ckpt.best_model_path == "": + continue + + if cfg.verbose >= 1: + pylog.info( + f"Test using best model file '{osp.basename(ckpt.best_model_path)}'..." + ) + ckpt_data = torch.load( + ckpt.best_model_path, + map_location=pl_module.device, + ) + pl_module.load_state_dict(ckpt_data["state_dict"]) + + if ckpt.monitor is not None: + monitor = ckpt.monitor # type: ignore + monitor = monitor[ + monitor.rfind("/") + 1 : + ] # ex: "val/fense" -> "fense" + model_name = f"best_{monitor}" + else: + model_name = "best" + + evaluator.set_model_name(model_name) + trainer.test(pl_module, datamodule=datamodule, verbose=cfg.verbose >= 3) + trainer.predict(pl_module, datamodule=datamodule) + n_tests_done += 1 + + if n_tests_done == 0: + if "last" not in cfg.testing.run: + pylog.warning( + "Cannot find best checkpoint callback, but testing will be done using last weights." 
+ ) + evaluator.set_model_name("last") + trainer.test(pl_module, datamodule=datamodule, verbose=cfg.verbose >= 3) + trainer.predict(pl_module, datamodule=datamodule) + + else: + pylog.error("Cannot find best checkpoint callback.") + + +@hydra.main( + version_base=None, + config_path=osp.join("..", "conf"), + config_name="train", +) +def main_train(cfg: DictConfig) -> Union[None, float]: + """Train a model on data.""" + run_start = time.perf_counter() + + # --- 1/6 - Set seed, init loggers, print config. + setup_run(cfg) + + # --- 2/6 - Build transforms & tokenizers + audio_tfms_cfgs = { + "train_audio_tfm": cfg.audio_t.train, + "val_audio_tfm": cfg.audio_t.val, + "test_audio_tfm": cfg.audio_t.test, + } + audio_tfms = { + name: instantiate(trans_cfg) for name, trans_cfg in audio_tfms_cfgs.items() + } + audio_tfms: dict[str, Callable] = { + name: trans for name, trans in audio_tfms.items() if trans is not None + } + + train_tokenizers_cfgs = { + "train_tokenizer": cfg.train_tok, + } + train_tokenizers = { + name: instantiate(tok_cfg) for name, tok_cfg in train_tokenizers_cfgs.items() + } + train_tokenizers = { + name: tokenizer + for name, tokenizer in train_tokenizers.items() + if tokenizer is not None + } + + test_tokenizer = instantiate(cfg.test_tok) + test_tokenizers = {"test_tokenizer": test_tokenizer} + tokenizers = train_tokenizers | test_tokenizers + + # --- 3/6 - Build pytorch lightning modules & callbacks + datamodule = instantiate(cfg.dm, **audio_tfms, **train_tokenizers) + pl_module = instantiate(cfg.pl, **train_tokenizers) + + # Callbacks + pl_loggers = [] + + tb_logger = instantiate(cfg.logger) + pl_loggers.append(tb_logger) + + callbacks = load_callbacks(cfg, tokenizers, datamodule, pl_module) + + # --- 4/6 - Build Trainer & run it + fit_trainer: Trainer = instantiate( + cfg.trainer, + logger=pl_loggers, + callbacks=list(callbacks.values()), + ) + + eval_trainer: Optional[Trainer] + if fit_trainer.num_devices == 1: + eval_trainer = fit_trainer + elif fit_trainer.is_global_zero: + eval_trainer = instantiate( + cfg.trainer, + logger=pl_loggers, + num_nodes=1, + devices=1, + callbacks=list(callbacks.values()), + plugins=LightningEnvironment(), + strategy=None, + ) + else: + eval_trainer = None + + if cfg.trainer.auto_scale_batch_size is not None: + # auto_scale_batch_size: None | "power" | "binsearch" + if cfg.verbose >= 1: + pylog.info( + f"Start tuning batch size with mode={cfg.trainer.auto_scale_batch_size}..." + ) + + if not hasattr(datamodule, "TUNE_MODE"): + raise ValueError("DM does not have 'TUNE_MODE' global param.") + + datamodule.TUNE_MODE = True + # Setup dm et plm because tuner needs model to be built on start + datamodule.setup("fit") + pl_module.setup("fit") + + fit_trainer.tune( + pl_module, + datamodule=datamodule, + scale_batch_size_kwargs=dict(init_val=8, batch_arg_name="_bsize"), + ) + return None + + # Validate & test before fit + evaluator: Optional[AACEvaluator] = callbacks.get("evaluator") # type: ignore + + if ( + eval_trainer is not None + and evaluator is not None + and fit_trainer.max_epochs is not None + and fit_trainer.max_epochs > 0 + and ( + fit_trainer.limit_train_batches is None + or fit_trainer.limit_train_batches > 0 + ) + ): + pylog.debug(f"Fit trainer = eval trainer? 
{fit_trainer is eval_trainer}") + + if cfg.val_on_start: + pylog.debug("Validate on start...") + eval_trainer.validate(pl_module, datamodule=datamodule, verbose=False) + pylog.debug("Validate on start done.") + + if cfg.test_on_start and (cfg.resume is not None or cfg.resume_2 is not None): + pylog.debug("Test on start...") + evaluator.set_model_name("start") + eval_trainer.test(pl_module, datamodule=datamodule, verbose=False) + evaluator.set_model_name("unk") + pylog.debug("Test on start done.") + + # Main training + pylog.debug("Fit model...") + fit_trainer.fit(pl_module, datamodule=datamodule) + pylog.debug("Fit model done.") + + # --- 5/6 - Test checkpoints + # Destroy group for testing on rank 0 after fit when using DDP + if fit_trainer.num_devices > 1: + torch.distributed.destroy_process_group() # type: ignore + + if evaluator is not None and eval_trainer is not None: + pylog.debug(f"Test after fit... (testing={cfg.testing.run})") + test_after_fit(cfg, datamodule, pl_module, evaluator, callbacks, eval_trainer) + pylog.debug(f"Test after fit done. (testing={cfg.testing.run})") + else: + pylog.info("Skip testing after fit.") + + # --- 6/6 - Close files and clean objects + run_end = time.perf_counter() + total_duration_s = run_end - run_start + total_duration_h = total_duration_s / 3600.0 + + stats_saver = callbacks.get("stats_saver") + if isinstance(stats_saver, StatsSaver) and eval_trainer is not None: + stats_saver.save_metrics_stats( + eval_trainer, + pl_module, + datamodule, + add_metrics=dict(total_duration_h=total_duration_h), + ) + + if cfg.out_crit is not None and isinstance(tb_logger, CustomTensorboardLogger): + out = tb_logger.metrics.get(cfg.out_crit, cfg.out_default) + if cfg.verbose >= 1: + pylog.info(f"Training is finished with {cfg.out_crit}={out}.") + else: + out = cfg.out_default + + teardown_run(cfg, run_start, run_end) + return out + + +if __name__ == "__main__": + main_train() diff --git a/src/conette/transforms/audio/__init__.py b/src/conette/transforms/audio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/conette/transforms/audio/cutoutspec.py b/src/conette/transforms/audio/cutoutspec.py new file mode 100644 index 000000000..50ed4059a --- /dev/null +++ b/src/conette/transforms/audio/cutoutspec.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math +import random + +from typing import Iterable, Optional, Union + +import torch + +from torch import nn, Tensor +from torch.distributions import Uniform + + +class CutOutSpec(nn.Module): + def __init__( + self, + freq_size_range: Union[tuple[float, float], tuple[int, int]] = (0.1, 0.5), + time_size_range: Union[tuple[float, float], tuple[int, int]] = (0.1, 0.5), + fill_value: Union[float, tuple[float, float]] = -100.0, + fill_mode: Union[str, nn.Module] = "constant", + freq_dim: int = -1, + time_dim: int = -2, + p: float = 1.0, + ) -> None: + """ + CutOut transform for spectrogram PyTorch tensors. + + Input must be of shape (..., freq, time), but you can specify frequency and time dimension dim/axis. + + Example 1 + ---------- + >>> from aac.transforms.augments.cutoutspec import CutOutSpec + >>> spectrogram = torch.rand((32, 1, 160, 64)) + >>> augment = CutOutSpec((0.5, 0.5), (0.5, 0.5)) + >>> # Remove 25% of the spectrogram values in a squared area + >>> spectrogram_augmented = augment(spectrogram) + + :param freq_size_range: The range of ratios for the frequencies dim. defaults to (0.1, 0.5). + :param time_size_range: The range of ratios for the time steps dim. 
defaults to (0.1, 0.5). + :param fill_value: The value used for fill. Can be a range of values for sampling the fill value. defaults to -100.0. + This parameter is ignored if fill_mode is a custom Module. + :param fill_mode: The fill mode. defaults to 'constant'. + Can be 'constant', 'random' or a custom transform for the data delimited by the rectange. + :param freq_dim: The dimension index of the spectrogram frequencies. defaults to -1. + :param time_dim: The dimension index of the spectrogram time steps. defaults to -2. + :param p: The probability to apply the transform. default to 1.0. + """ + assert 0.0 <= p <= 1.0 + super().__init__() + + self.freq_size_range = freq_size_range + self.time_size_range = time_size_range + self.fill_value = fill_value + self.fill_mode = fill_mode + self.freq_dim = freq_dim + self.time_dim = time_dim + self.p = p + + self._check_attributes() + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "freq_size_range": self.freq_size_range, + "time_size_range": self.time_size_range, + "fill_value": self.fill_value, + "fill_mode": self.fill_mode, + "freq_dim": self.freq_dim, + "time_dim": self.time_dim, + "p": self.p, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + if self.p >= 1.0 or random.random() < self.p: + return self.apply_transform(x) + else: + return x + + # Other methods + def apply_transform(self, data: Tensor) -> Tensor: + if not isinstance(data, Tensor) or data.ndim < 2: + raise ValueError( + f"Input data must be a PyTorch Tensor with at least 2 dimensions for {self.__class__.__name__} transform, " + f"found {type(data)}" + + (f" of shape {data.shape}" if hasattr(data, "shape") else "") + + "." + ) + + # Prepare slices indexes for frequencies and time dimensions + slices = [slice(None)] * data.ndim + slices[self.freq_dim] = gen_range( + data.shape[self.freq_dim], self.freq_size_range + ) + slices[self.time_dim] = gen_range( + data.shape[self.time_dim], self.time_size_range + ) + + if self.fill_mode == "constant": + data[slices] = self._gen_constant(data[slices]) + + elif self.fill_mode == "random": + data[slices] = self._gen_random(data[slices]) + + elif isinstance(self.fill_mode, nn.Module): + data[slices] = self.fill_mode(data[slices]) + + else: + raise ValueError( + f'Invalid fill_mode "{self.fill_mode}". ' + f'Must be one of "{("constant", "random")}" or a custom transform Module.' + ) + + return data + + def _gen_constant(self, data: Tensor) -> Tensor: + if isinstance(self.fill_value, float): + fill_value = self.fill_value + else: + uniform = Uniform(*self.fill_value) # type: ignore + fill_value = uniform.sample() + return torch.full_like(data, fill_value) + + def _gen_random(self, data: Tensor) -> Tensor: + if isinstance(self.fill_value, float): + raise ValueError( + "Invalid fill_value with random fill_mode. Please use a tuple of 2 floats for fill_value or use " + 'fill_mode="constant".' + ) + else: + uniform = Uniform(*self.fill_value) # type: ignore + return uniform.sample(data.shape) + + def _check_attributes(self) -> None: + if self.freq_dim == self.time_dim: + raise ValueError( + "Frequency dimension index cannot be the same than time dimension index." + ) + + if not isinstance(self.fill_value, float) and not ( + isinstance(self.fill_value, tuple) and len(self.fill_value) == 2 + ): + raise ValueError( + f'Invalid fill_value "{self.fill_value}", must be a float or a tuple of 2 floats.' 
+ ) + + if self.fill_mode == "random" and isinstance(self.fill_value, float): + raise ValueError( + "Invalid fill_value with random fill_mode. Please use a tuple of 2 floats for fill_value or use " + 'fill_mode="constant".' + ) + + +def gen_range( + size: int, + scales: Union[Iterable[float], Iterable[int]], + generator: Optional[torch.Generator] = None, +) -> slice: + """ + Generate an random range in [0, size]. + The position of the range is random. + + :param size: The size of the array. + :param scales: The scales attributes defined the length of the range. + If scales is (float, float), the int length will be sampled from [ ceil(size * scales[0]), ceil(size * scales[1]) ]. + If scales is (int, int), the int length will be sampled from scales. + + Example 1 + ---------- + >>> gen_range(size=100, scales=(0.5, 0.5)) + ... slice(10, 60) + """ + if not isinstance(scales, Iterable): + raise ValueError( + f"Invalid argument {scales=}. (expected tuple[int, int] or tuple[float, float])" + ) + scales = list(scales) + if len(scales) != 2: + raise ValueError( + f"Invalid argument {scales=}. (expected tuple[int, int] or tuple[float, float])" + ) + if not all(isinstance(s, float) for s in scales) and not all( + isinstance(s, int) for s in scales + ): + raise ValueError( + f"Invalid argument {scales=}. (expected tuple[int, int] or tuple[float, float])" + ) + + if isinstance(scales[0], float): + cutout_size_min = math.ceil(scales[0] * size) + cutout_size_max = max(math.ceil(scales[1] * size), cutout_size_min + 1) + elif isinstance(scales[0], int): + cutout_size_min: int = scales[0] # type: ignore + cutout_size_max: int = scales[1] # type: ignore + else: + raise ValueError( + f"Invalid argument {scales=}. (expected tuple[int, int] or tuple[float, float])" + ) + + cutout_size = int( + torch.randint(cutout_size_min, cutout_size_max, (), generator=generator).item() + ) + cutout_start = torch.randint( + 0, max(size - cutout_size + 1, 1), (), generator=generator + ) + cutout_end = cutout_start + cutout_size + assert ( + cutout_end - cutout_start == cutout_size + ), f"{cutout_end} - {cutout_start} != {cutout_size}" + + return slice(cutout_start, cutout_end) diff --git a/src/conette/transforms/audio/resample.py b/src/conette/transforms/audio/resample.py new file mode 100644 index 000000000..a081def99 --- /dev/null +++ b/src/conette/transforms/audio/resample.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math +import random + +from typing import Any, List, Tuple + +import torch + +from torch import nn, Tensor +from torch.distributions import Uniform + + +class ResampleNearest(nn.Module): + def __init__( + self, + rates: Tuple[float, float] = (0.5, 1.5), + dim: int = -1, + p: float = 1.0, + ) -> None: + """Resample an audio waveform signal. + + :param rates: The rate of the stretch. Ex: use 2.0 for multiply the signal length by 2. (default: (0.5, 1.5)) + :param dim: The dimension to modify. (default: -1) + :param p: The probability to apply the transform. 
(default: 1.0) + """ + assert 0.0 <= p + super().__init__() + self.rates = rates + self.dim = dim + self.p = p + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "rates": self.rates, + "dim": self.dim, + "p": self.p, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + floor_p = math.floor(self.p) + for _ in range(floor_p): + x = self.apply_transform(x) + + rest = self.p - floor_p + if rest > 0.0 and rest < random.random(): + return self.apply_transform(x) + else: + return x + + # Other methods + def apply_transform(self, x: Tensor) -> Tensor: + """Apply the transform without taking into account the probability p.""" + if self.rates[0] == self.rates[1]: + rate = self.rates[0] + else: + sampler = Uniform(*self.rates) + rate = sampler.sample().item() + + x = self._resample_nearest(x, rate) + return x + + def _resample_nearest(self, x: Tensor, rate: float) -> Tensor: + length = x.shape[self.dim] + step = 1.0 / rate + indexes = torch.arange(0, length, step) + indexes = indexes.round().long().clamp(max=length - 1) + slices: List[Any] = [slice(None)] * len(x.shape) + slices[self.dim] = indexes + x = x[slices] + x = x.contiguous() + return x diff --git a/src/conette/transforms/audio/spec_aug.py b/src/conette/transforms/audio/spec_aug.py new file mode 100644 index 000000000..a2dbcadb1 --- /dev/null +++ b/src/conette/transforms/audio/spec_aug.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Provides modules to use SpecAugment +BASED ON https://github.com/qiuqiangkong/sound_event_detection_dcase2017_task4 +MODIFIED: Yes (typing, spectrogram reshape, add probability of specaugment) +""" + +import random + +from typing import Union + +import torch + +from torch import nn, Tensor + + +class DropStripes(nn.Module): + def __init__( + self, + max_width: int, + stripes_num: int, + dim: int, + fill_value: float = 0.0, + inplace: bool = True, + generator: Union[int, torch.Generator, None] = None, + ) -> None: + """Drop stripes. + + :param dim: int, dimension along which to drop + :param drop_width: int, maximum width of stripes to drop + :param stripes_num: int, how many stripes to drop + :param fill_value: the value used to mask stripes + """ + if max_width <= 0: + raise ValueError( + f"Invalid argument {max_width=} in {self.__class__.__name__}. (expected a value > 0)" + ) + + if isinstance(generator, int): + generator = torch.Generator().manual_seed(generator) + + super().__init__() + self.dim = dim + self.max_width = max_width + self.stripes_num = stripes_num + self.fill_value = fill_value + self.inplace = inplace + self.generator = generator + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "dim": self.dim, + "max_width": self.max_width, + "stripes_num": self.stripes_num, + "fill_value": self.fill_value, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + total_width = x.shape[self.dim] + + # Add: If audio is empty, do nothing + if total_width == 0: + return x + + # Add: If audio is shorter than self.drop_width, clip drop width. 
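+        # Widths are sampled below in [0, max_width) (randint's upper bound is
+        # exclusive), so a stripe may have width 0 and leave x unchanged.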
+ max_width = min(self.max_width, total_width) + + widths = torch.randint( + low=0, high=max_width, size=(self.stripes_num,), generator=self.generator + ).tolist() + starts = [ + torch.randint( + low=0, high=total_width - size, size=(), generator=self.generator + ) + for size in widths + ] + + if not self.inplace: + x = x.clone() + + for width, start in zip(widths, starts): + slices = [slice(None) for _ in range(x.ndim)] + slices[self.dim] = slice(start, start + width) + x[slices] = self.fill_value + + return x + + +class SpecAugment(nn.Module): + def __init__( + self, + time_max_width: int, + time_stripes_num: int, + freq_max_width: int, + freq_stripes_num: int, + time_dim: int = -2, + freq_dim: int = -1, + fill_value: float = 0.0, + p: float = 1.0, + ) -> None: + """Spec augmentation. + [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. + and Le, Q.V., 2019. Specaugment: A simple data augmentation method + for automatic speech recognition. arXiv preprint arXiv:1904.08779. + + Args: + time_drop_width: int + time_stripes_num: int + freq_drop_width: int + freq_stripes_num: int + """ + assert 0.0 <= p <= 1.0 + super().__init__() + self.p = p + + self.time_dropper = DropStripes( + max_width=time_max_width, + stripes_num=time_stripes_num, + dim=time_dim, + fill_value=fill_value, + ) + self.freq_dropper = DropStripes( + max_width=freq_max_width, + stripes_num=freq_stripes_num, + dim=freq_dim, + fill_value=fill_value, + ) + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "p": self.p, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + if self.p >= 1.0 or random.random() < self.p: + return self.apply_transform(x) + else: + return x + + # Other methods + def apply_transform(self, x: Tensor) -> Tensor: + x = self.time_dropper(x) + x = self.freq_dropper(x) + return x + + +class DropStripesRatio(nn.Module): + def __init__( + self, + ratios: tuple[float, float], + stripes_num: int, + dim: int, + fill_value: float = 0.0, + generator: Union[int, torch.Generator, None] = None, + inplace: bool = True, + ) -> None: + if not (0.0 <= ratios[0] <= ratios[1] <= 1.0): + raise ValueError( + f"Invalid argument {ratios=}. 
(expected a tuple of two floats in [0, 1], with ratios[0] <= ratios[1])" + ) + + if isinstance(generator, int): + generator = torch.Generator().manual_seed(generator) + + super().__init__() + self.ratios = ratios + self.stripes_num = stripes_num + self.dim = dim + self.fill_value = fill_value + self.generator = generator + self.inplace = inplace + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "dim": self.dim, + "max_width": self.max_width, + "stripes_num": self.stripes_num, + "fill_value": self.fill_value, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + total_width = x.shape[self.dim] + # If audio is empty, do nothing + if total_width == 0: + return x + imin = round(total_width * self.ratios[0]) + imax = round(total_width * self.ratios[1]) + + if imin > imax: + return x + elif imin == imax: + widths = torch.full((self.stripes_num,), imin) + else: + widths = torch.randint( + imin, imax, (self.stripes_num,), generator=self.generator + ) + + starts = [ + torch.randint(low=0, high=total_width - size, size=(), generator=self.generator) for size in widths # type: ignore + ] + + if not self.inplace: + x = x.clone() + + for width, start in zip(widths, starts): + slices = [slice(None) for _ in range(x.ndim)] + slices[self.dim] = slice(start, start + width) + x[slices] = self.fill_value + + return x + + +class SpecAugmentRatio(nn.Module): + def __init__( + self, + time_ratios: tuple[float, float], + time_stripes_num: int, + freq_ratios: tuple[float, float], + freq_stripes_num: int, + time_dim: int = -2, + freq_dim: int = -1, + fill_value: float = 0.0, + inplace: bool = True, + p: float = 1.0, + ) -> None: + assert 0.0 <= p <= 1.0 + super().__init__() + self.p = p + + self.time_dropper = DropStripesRatio( + ratios=time_ratios, + stripes_num=time_stripes_num, + dim=time_dim, + fill_value=fill_value, + inplace=inplace, + ) + self.freq_dropper = DropStripesRatio( + ratios=freq_ratios, + stripes_num=freq_stripes_num, + dim=freq_dim, + fill_value=fill_value, + inplace=inplace, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.p >= 1.0 or random.random() < self.p: + return self.apply_transform(x) + else: + return x + + def apply_transform(self, x: Tensor) -> Tensor: + x = self.time_dropper(x) + x = self.freq_dropper(x) + return x diff --git a/src/conette/transforms/audio/speed_perturb.py b/src/conette/transforms/audio/speed_perturb.py new file mode 100644 index 000000000..4ab2c1d11 --- /dev/null +++ b/src/conette/transforms/audio/speed_perturb.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import math +import random + +from typing import Tuple, Union + +from torch import nn, Tensor + +from conette.nn.modules.crop import CropDim +from conette.nn.modules.pad import PadDim +from conette.transforms.audio.resample import ResampleNearest + + +class SpeedPerturbation(nn.Module): + def __init__( + self, + rates: Tuple[float, float] = (0.9, 1.1), + target_length: Union[int, str, None] = "same", + align: str = "random", + fill_value: float = 0.0, + dim: int = -1, + p: float = 1.0, + ) -> None: + """Resample, Pad and Crop a signal. + + :param rates: The ratio of the signal used for resize. defaults to (0.9, 1.1). + :param target_length: Optional target length of the signal dimension. + If 'same', the output will have the same shape than the input. + defaults to "same". + :param align: Alignment to use for cropping and padding. Can be 'left', 'right', 'center' or 'random'. + defaults to "random". 
+ :param fill_value: The fill value when padding the waveform. defaults to 0.0. + :param dim: The dimension to stretch and pad or crop. defaults to -1. + :param p: The probability to apply the transform. defaults to 1.0. + """ + assert 0.0 <= p + rates = tuple(rates) # type: ignore + + super().__init__() + self.rates = rates + self._target_length = target_length + self.align = align + self.fill_value = fill_value + self.dim = dim + self.p = p + + target_length = self.target_length if isinstance(self.target_length, int) else 1 + self.resampler = ResampleNearest(rates, dim=dim) + self.pad = PadDim(target_length, align, fill_value, dim, mode="constant") + self.crop = CropDim(target_length, align, dim) + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "rates": self.rates, + "target_length": self.target_length, + "align": self.align, + "fill_value": self.fill_value, + "dim": self.dim, + "p": self.p, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def forward(self, x: Tensor) -> Tensor: + floor_p = math.floor(self.p) + for _ in range(floor_p): + x = self.apply_transform(x) + + rest = self.p - floor_p + if rest > 0.0 and rest < random.random(): + return self.apply_transform(x) + else: + return x + + # Other methods + def apply_transform(self, x: Tensor) -> Tensor: + """Apply the transform without taking into account the probability p.""" + if self.target_length == "same": + target_length = x.shape[self.dim] + self.pad.target_length = target_length + self.crop.target_length = target_length + + x = self.resampler(x) + + if self.target_length is not None: + x = self.pad(x) + x = self.crop(x) + return x + + @property + def target_length(self) -> Union[int, str, None]: + return self._target_length diff --git a/src/conette/transforms/audioset_labels.py b/src/conette/transforms/audioset_labels.py new file mode 100644 index 000000000..d78a55172 --- /dev/null +++ b/src/conette/transforms/audioset_labels.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import os +import os.path as osp + +from pathlib import Path +from typing import Union + +import torch + +from torch import Tensor +from torch.hub import download_url_to_file + + +AUDIOSET_INFOS = { + "class_labels_indices": { + "fname": "class_labels_indices.csv", + "url": "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv", + }, +} + + +def probs_to_binarized( + probs: Tensor, + threshold: Union[float, Tensor], +) -> Tensor: + """Perform thresholding to binarize probabilities.""" + if probs.ndim != 2: + raise ValueError( + "Invalid argument probs. (expected a batch of probabilities of shape (N, n_classes))." 
+ ) + nb_classes = probs.shape[1] + + if isinstance(threshold, Tensor) and threshold.ndim == 1: + threshold = threshold.item() + + if isinstance(threshold, (float, int)): + threshold = torch.full( + (nb_classes,), threshold, dtype=torch.float, device=probs.device + ) + else: + if threshold.shape[1] != nb_classes: + raise ValueError("Invalid argument threshold.") + threshold = threshold.to(device=probs.device) + + binarized = probs >= threshold + return binarized + + +def binarized_to_indices( + binarized: Tensor, +) -> list[list[int]]: + """Convert binarized probs to list of indexes.""" + preds = [] + for binarized_i in binarized: + preds_i = torch.where(binarized_i)[0].tolist() + preds.append(preds_i) + return preds + + +def probs_to_indices( + probs: Tensor, + threshold: Union[float, Tensor], +) -> list[list[int]]: + """Convert probs to list of indexes.""" + binarized = probs_to_binarized(probs, threshold) + preds = binarized_to_indices(binarized) + return preds + + +def probs_to_labels( + probs: Tensor, + threshold: Union[float, Tensor], + offline: bool = False, + verbose: int = 0, +) -> list[list[str]]: + """Convert probs to list of labels.""" + indices = probs_to_indices(probs, threshold) + labels = indices_to_labels(indices, offline, verbose) + return labels + + +def indices_to_labels( + indices: Union[list[list[int]], list[Tensor]], + offline: bool = False, + verbose: int = 0, +) -> list[list[str]]: + """Convert indices to list of labels.""" + name_to_idx = load_audioset_mapping(offline, verbose) + idx_to_name = {idx: name for name, idx in name_to_idx.items()} + + labels = [] + for indices_i in indices: + names = [idx_to_name[idx] for idx in indices_i] # type: ignore + labels.append(names) + return labels + + +def get_audioset_mapping_dir_path() -> Path: + dpath = Path.home().joinpath(".cache", "conette") + return dpath + + +def load_audioset_mapping(offline: bool = False, verbose: int = 0) -> dict[str, int]: + info = AUDIOSET_INFOS["class_labels_indices"] + dpath = get_audioset_mapping_dir_path() + + map_fname = info["fname"] + map_fpath = dpath.joinpath(map_fname) + + if not osp.isfile(map_fpath): + if offline: + raise FileNotFoundError( + f"Cannot find or download audioset mapping file in '{map_fpath}' with mode {offline=}." 
+ ) + + download_audioset_mapping(verbose) + + with open(map_fpath, "r") as file: + reader = csv.DictReader(file, skipinitialspace=True, strict=True) + data = list(reader) + + name_to_index = {info["display_name"]: int(info["index"]) for info in data} + return name_to_index + + +def download_audioset_mapping(verbose: int = 0) -> None: + info = AUDIOSET_INFOS["class_labels_indices"] + dpath = get_audioset_mapping_dir_path() + map_fname = info["fname"] + map_fpath = dpath.joinpath(map_fname) + + url = info["url"] + os.makedirs(dpath, exist_ok=True) + download_url_to_file(url, str(map_fpath), progress=verbose >= 1) diff --git a/src/conette/transforms/get.py b/src/conette/transforms/get.py new file mode 100644 index 000000000..1822dd567 --- /dev/null +++ b/src/conette/transforms/get.py @@ -0,0 +1,670 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os.path as osp +import pickle + +from pathlib import Path +from typing import Any, Optional, Union + +import torch +import yaml + +from nnAudio.features import Gammatonegram +from torch import nn, Tensor +from torchaudio.transforms import Resample +from torchlibrosa.stft import Spectrogram, LogmelFilterBank + +from conette.nn.encoders.convnext import convnext_tiny +from conette.nn.encoders.cnn10 import Cnn10 +from conette.nn.encoders.cnn14_decisionlevel_att import Cnn14_DecisionLevelAtt +from conette.nn.encoders.cnn14 import Cnn14 +from conette.nn.functional.get import get_device +from conette.nn.modules.misc import ( + Lambda, + Standardize, +) +from conette.nn.modules.tensor import ( + Mean, + Permute, + Squeeze, + TensorTo, + Unsqueeze, +) +from conette.nn.pann_utils.ckpt import pann_load_state_dict +from conette.transforms.audio.spec_aug import SpecAugment + + +pylog = logging.getLogger(__name__) + + +def get_none() -> None: + # Returns None. Can be used for hydra instantiations. + return None + + +def get_pickle( + fpath: Union[str, Path], +) -> nn.Module: + if not isinstance(fpath, (str, Path)): + raise TypeError(f"Invalid transform with pickle {fpath=}. (not a str or Path)") + if not osp.isfile(fpath): + raise FileNotFoundError(f"Invalid transform with pickle {fpath=}. 
(not a file)") + + with open(fpath, "rb") as file: + transform = pickle.load(file) + return transform + + +def get_resample_mean( + src_sr: int, + tgt_sr: int, + mean_dim: Optional[int] = 0, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + ) + + +def get_resample_mean_cnn10( + src_sr: int, + tgt_sr: int, + mean_dim: Optional[int] = 0, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + device: Union[str, torch.device, None] = "auto", + transpose_frame_embs: bool = True, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + device = get_device(device) + + encoder = Cnn10( + sr=tgt_sr, + window_size=window_size, + hop_size=hop_size, + mel_bins=mel_bins, + return_clip_outputs=True, + return_frame_outputs=False, + pretrained=False, + waveform_input=True, + ) + for p in encoder.parameters(): + p.requires_grad_(False) + encoder.eval() + + state_dict = pann_load_state_dict("Cnn10", "cpu", offline=False) + encoder.load_state_dict(state_dict) + + encoder = encoder.to(device) + + def get_cnn10_embs(wave: Tensor) -> dict[str, Tensor]: + wave = wave.unsqueeze_(dim=0) + wave_shape = torch.as_tensor(wave.shape[1:]).unsqueeze_(dim=0) + out = encoder(wave, wave_shape) + + if transpose_frame_embs: + # Transpose (n_channels=1, features=512, time) -> (n_channels=1, time, features=512) + out["frame_embs"] = out["frame_embs"].transpose(1, 2) + + return out + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + Unsqueeze(dim=0), + TensorTo(device=device), + Lambda(get_cnn10_embs), + ) + + +def get_resample_mean_cnn14_att( + src_sr: int, + tgt_sr: int, + mean_dim: Optional[int] = 0, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + device: Union[str, torch.device, None] = "auto", + transpose_frame_embs: bool = True, + only_frame_embs: bool = True, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + device = get_device(device) + + encoder = Cnn14_DecisionLevelAtt( + sr=tgt_sr, + window_size=window_size, + hop_size=hop_size, + mel_bins=mel_bins, + return_clip_outputs=True, + pretrained=True, + ) + for p in encoder.parameters(): + p.requires_grad_(False) + encoder.eval() + encoder = encoder.to(device) + + def get_cnn14_embs(wave: Tensor) -> Any: + # Add batch dim + wave = wave.unsqueeze_(dim=0) + wave_shape = torch.as_tensor(wave.shape[1:]).unsqueeze_(dim=0) + out = encoder(wave, wave_shape) + frame_embs = out.pop("frame_embs") + + if transpose_frame_embs: + # Transpose (n_channels, features=2048, time) -> (n_channels, time, features=2048) + frame_embs = frame_embs.transpose(1, 2) + + if only_frame_embs: + return frame_embs + else: + out["frame_embs"] = frame_embs + return out + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + TensorTo(device=device), + Lambda(get_cnn14_embs), + ) + + +def get_resample_mean_cnn14( + src_sr: int, + tgt_sr: int, + mean_dim: Optional[int] = 0, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + device: Union[str, torch.device, None] = "auto", + transpose_frame_embs: bool = True, + only_frame_embs: bool = True, + pretrain_path: Optional[str] = None, +) -> 
nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + device = get_device(device) + encoder = Cnn14( + sample_rate=tgt_sr, + window_size=window_size, + hop_size=hop_size, + mel_bins=mel_bins, + waveform_input=True, + use_specaug=False, + return_clip_outputs=True, + return_frame_outputs=True, + ) + + if pretrain_path is None: + state_dict = pann_load_state_dict("Cnn14", device, offline=False) + else: + state_dict = pann_load_state_dict(pretrain_path, device) + encoder.load_state_dict(state_dict) + + for p in encoder.parameters(): + p.requires_grad_(False) + encoder.eval() + encoder = encoder.to(device=device) + + def get_cnn14_embs(wave: Tensor) -> Any: + # Add batch dim + wave = wave.unsqueeze_(dim=0) + wave_shape = torch.as_tensor(wave.shape[1:]).unsqueeze_(dim=0) + out = encoder(wave, wave_shape) + frame_embs = out.pop("frame_embs") + + if transpose_frame_embs: + # Transpose (n_channels, features=2048, time) -> (n_channels, time, features=2048) + frame_embs = frame_embs.transpose(1, 2) + + if only_frame_embs: + return frame_embs + else: + # note: empty string will use "audio" column in HDF instead of "audio.frame_embs". + # see aac/datasets/hdf/pack.py source code for details. + out[""] = frame_embs + return out + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + TensorTo(device=device), + Lambda(get_cnn14_embs), + ) + + +def get_resample_mean_convnext( + src_sr: int, + tgt_sr: int, + mean_dim: Optional[int] = 0, + device: Union[str, torch.device, None] = "auto", + transpose_frame_embs: bool = True, + only_frame_embs: bool = True, + pretrain_path: Optional[str] = None, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + if not isinstance(pretrain_path, str): + raise ValueError( + f"Invalid argument type {type(pretrain_path)=}. (expected str)" + ) + + device = get_device(device) + encoder = convnext_tiny( + pretrained=False, + strict=False, + drop_path_rate=0.0, + after_stem_dim=[252, 56], + use_speed_perturb=False, + waveform_input=True, + use_specaug=False, + return_clip_outputs=True, + return_frame_outputs=True, + ) + + data = torch.load(pretrain_path, map_location=torch.device("cpu")) + state_dict = data["model"] + encoder.load_state_dict(state_dict, strict=False) + + for p in encoder.parameters(): + p.requires_grad_(False) + encoder.eval() + encoder = encoder.to(device=device) + + def get_model_outputs(wave: Tensor) -> Any: + # Add batch dim + wave = wave.unsqueeze_(dim=0) + wave_shape = torch.as_tensor(wave.shape[1:]).unsqueeze_(dim=0) + out = encoder(wave, wave_shape) + frame_embs = out.pop("frame_embs") + + if transpose_frame_embs: + # Transpose (n_channels, features=768, time) -> (n_channels, time, features=768) + frame_embs = frame_embs.transpose(1, 2) + + if only_frame_embs: + return frame_embs + else: + # note: empty string will use "audio" column in HDF instead of "audio.frame_embs". + # see aac/datasets/hdf/pack.py source code for details. 
+ out[""] = frame_embs + return out + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + TensorTo(device=device), + Lambda(get_model_outputs), + ) + + +def get_resample_mean_spec( + src_sr: int, + tgt_sr: int, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + ref: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + freeze_parameters: bool = True, + mean_dim: Optional[int] = 0, + device: Union[str, torch.device, None] = "auto", +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + device = get_device(device) + + to_spectro = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + to_logmel = LogmelFilterBank( + sr=tgt_sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_parameters, + ) + + to_spectro = to_spectro.to(device=device) + to_logmel = to_logmel.to(device=device) + + transform = nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + Unsqueeze(dim=0), + TensorTo(device=device), + to_spectro, + to_logmel, + Squeeze(dim=0), + ) + return transform + + +def get_resample_spec_mean_spec_aug( + src_sr: int, + tgt_sr: int, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + ref: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + freeze_parameters: bool = True, + mean_dim: Optional[int] = 0, + time_drop_width: int = 64, + time_stripes_num: int = 2, + freq_drop_width: int = 2, + freq_stripes_num: int = 1, + spec_aug_p: float = 1.0, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ), + LogmelFilterBank( + sr=tgt_sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_parameters, + ), + Mean(dim=mean_dim), + SpecAugment( + time_max_width=time_drop_width, + time_stripes_num=time_stripes_num, + freq_max_width=freq_drop_width, + freq_stripes_num=freq_stripes_num, + p=spec_aug_p, + ), + ) + + +def get_resample_spec_mean( + src_sr: int, + tgt_sr: int, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + ref: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + freeze_parameters: bool = True, + mean_dim: Optional[int] = 0, + device: Union[str, torch.device, None] = "auto", +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + device = get_device(device) + + to_spectro = Spectrogram( + n_fft=window_size, + 
hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + to_logmel = LogmelFilterBank( + sr=tgt_sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_parameters, + ) + + to_spectro = to_spectro.to(device=device) + to_logmel = to_logmel.to(device=device) + + transform = nn.Sequential( + Resample(src_sr, tgt_sr), + TensorTo(device=device), + to_spectro, + to_logmel, + Mean(dim=mean_dim), + ) + return transform + + +def get_resample_mean_gamma_perm( + src_sr: int, + tgt_sr: int, + mean_dim: int = 0, + n_fft: int = 1024, + n_bins: int = 64, + hop_length: int = 512, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + power: float = 2.0, + htk: bool = False, + fmin: float = 20.0, + fmax: Optional[float] = None, + norm: float = 1, + trainable_bins: bool = False, + trainable_STFT: bool = False, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + return nn.Sequential( + Resample(src_sr, tgt_sr), + Mean(dim=mean_dim), + Gammatonegram( + sr=tgt_sr, + n_fft=n_fft, + n_bins=n_bins, + hop_length=hop_length, + window=window, + center=center, + pad_mode=pad_mode, + power=power, + htk=htk, + fmin=fmin, + fmax=fmax, + norm=norm, # type: ignore + trainable_bins=trainable_bins, + trainable_STFT=trainable_STFT, + verbose=False, + ), + Permute(0, 2, 1), + ) + + +def get_stand_resample_spectro_mean_spec_aug( + src_sr: int, + tgt_sr: int, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + ref: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + freeze_parameters: bool = True, + mean_dim: Optional[int] = 0, + time_drop_width: int = 64, + time_stripes_num: int = 2, + freq_drop_width: int = 2, + freq_stripes_num: int = 1, + spec_aug_p: float = 1.0, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + return nn.Sequential( + Standardize(), + Resample(src_sr, tgt_sr), + Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ), + LogmelFilterBank( + sr=tgt_sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_parameters, + ), + Mean(dim=mean_dim), + SpecAugment( + time_max_width=time_drop_width, + time_stripes_num=time_stripes_num, + freq_max_width=freq_drop_width, + freq_stripes_num=freq_stripes_num, + p=spec_aug_p, + ), + ) + + +def get_stand_resample_spectro_mean( + src_sr: int, + tgt_sr: int, + window_size: int = 1024, + hop_size: int = 320, + mel_bins: int = 64, + fmin: int = 50, + fmax: int = 14000, + window: str = "hann", + center: bool = True, + pad_mode: str = "reflect", + ref: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + freeze_parameters: bool = True, + mean_dim: Optional[int] = 0, +) -> nn.Sequential: + if not isinstance(src_sr, int): + error_message = _get_error_message(src_sr) + pylog.error(error_message) + raise ValueError(error_message) + + return nn.Sequential( + 
Standardize(), + Resample(src_sr, tgt_sr), + Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ), + LogmelFilterBank( + sr=tgt_sr, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, # type: ignore + freeze_parameters=freeze_parameters, + ), + Mean(dim=mean_dim), + ) + + +def _get_error_message(src_sr: Any) -> str: + defaults_srs = {"clotho": 44100, "audiocaps": 32000, "macs": 48000} + defaults_srs = yaml.dump(defaults_srs, sort_keys=False) + message = ( + "\n" + f"Invalid sr={src_sr} for get_resample_mean() function.\n" + f"Please specify explicitely the source sample rate in Hz with audio_t.src_sr=SAMPLE_RATE.\n" + f"BE CAREFUL, sample rate can be different if you use pre-processed HDF files.\n" + f"Defaults sample rates are:\n{defaults_srs}" + ) + return message diff --git a/src/conette/transforms/mixup.py b/src/conette/transforms/mixup.py index 370aebe51..43117f028 100644 --- a/src/conette/transforms/mixup.py +++ b/src/conette/transforms/mixup.py @@ -1,11 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import random + from typing import Any, Iterable, Union import torch -from torch import Tensor +from torch import nn, Tensor from torch.distributions.beta import Beta @@ -53,3 +55,76 @@ def sample_lambda( if asymmetric: lbd = torch.max(lbd, 1.0 - lbd) return lbd + + +class Mixup(nn.Module): + """ + Mix linearly inputs with coefficient sampled from a Beta distribution. + """ + + def __init__( + self, + alpha: float = 0.4, + asymmetric: bool = False, + p: float = 1.0, + ) -> None: + """ + ``` + lambda ~ Beta(alpha, alpha) + x = lambda * x + (1.0 - lambda) * shuffle(x) + y = lambda * y + (1.0 - lambda) * shuffle(y) + ``` + + :param alpha: The parameter used by the beta distribution. + If alpha -> 0, the value sampled will be close to 0 or 1. + If alpha -> 1, the value will be sampled from a uniform distribution. + defaults to 0.4. + :param asymmetric: If True, the first coefficient will always be the higher one, which means the result will be closer to the input. + defaults to False. + :param p: The probability to apply the mixup. + defaults to 1.0. + """ + assert 0.0 <= p <= 1.0 + super().__init__() + self.alpha = alpha + self.asymmetric = asymmetric + self.p = p + + self.beta = Beta(alpha, alpha) + + # nn.Module methods + def extra_repr(self) -> str: + hparams = { + "alpha": self.alpha, + "asymmetric": self.asymmetric, + "p": self.p, + } + return ", ".join(f"{k}={v}" for k, v in hparams.items()) + + def __call__(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + # This method is here only for typing + return super().__call__(x, y) + + def forward(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + if self.p >= 1.0 or random.random() < self.p: + return self.apply_transform(x, y) + else: + return x, y + + # Other methods + def sample_lambda(self, size: Iterable[int] = ()) -> Tensor: + return sample_lambda(self.alpha, self.asymmetric, size) + + def apply_transform(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + if x.shape[0] != y.shape[0]: + raise ValueError( + f"Data to mix must have the same size along the first dim. 
({x.shape[0]=} != {y.shape[0]=})" + ) + + bsize = x.shape[0] + lbd = self.sample_lambda(()) + indexes = torch.randperm(bsize) + + x = x * lbd + x[indexes] * (1.0 - lbd) + y = y * lbd + y[indexes] * (1.0 - lbd) + return x, y diff --git a/src/conette/transforms/utils.py b/src/conette/transforms/utils.py new file mode 100644 index 000000000..d8cd706bf --- /dev/null +++ b/src/conette/transforms/utils.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Any, Callable, Iterable, Mapping, Optional + +from torch import Tensor + +from conette.utils.misc import pass_filter + + +class DictTransform(dict[str, Optional[Callable]]): + def __init__( + self, + transforms_dict: Optional[Mapping[str, Optional[Callable]]] = None, + **transforms_kwargs: Optional[Callable], + ) -> None: + """Wrap a dictionary of transforms to apply to each value of a dictionary input at a corresponding key. + + Example 1 + ---------- + ```py + >>> triple_a = DictTransform({"a": lambda x: x * 3}) + >>> input = {"a": 4, "b": 5} + >>> triple_a(input) + ... {"a": 12, "b": 5} + ``` + """ + if transforms_dict is None: + transforms_dict = {} + else: + transforms_dict = dict(transforms_dict) + transforms = transforms_dict | transforms_kwargs + super().__init__(transforms) + + def forward(self, item: dict[str, Any]) -> dict[str, Any]: + for name, transform in self.items(): + if transform is not None and name in item: + item[name] = transform(item[name]) + return item + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + +class ShapesToSizes: + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def __call__(self, x_shapes: Tensor) -> Tensor: + return x_shapes[:, self.dim] + + +class SelectColumns: + def __init__( + self, + /, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + ) -> None: + super().__init__() + self._include = include + self._exclude = exclude + + def __call__(self, item: Mapping[str, Any]) -> dict[str, Any]: + item = { + k: v + for k, v in item.items() + if pass_filter(k, self._include, self._exclude) + } + return item + + +class Rename: + def __init__(self, **kwargs: str) -> None: + super().__init__() + self.renames = kwargs + + def __call__(self, item: Mapping[str, Any]) -> dict[str, Any]: + item = {self.renames.get(k, k): v for k, v in item.items()} + return item + + +class Compose: + def __init__(self, *fns: Callable) -> None: + super().__init__() + self.fns = fns + + def __call__(self, x: Any) -> Any: + for fn in self.fns: + x = fn(x) + return x diff --git a/src/conette/utils/cmdline.py b/src/conette/utils/cmdline.py index 698c5d4d0..bb3203fbe 100644 --- a/src/conette/utils/cmdline.py +++ b/src/conette/utils/cmdline.py @@ -4,9 +4,9 @@ import logging import sys -from typing import Optional - +from typing import Optional, Union +_NONE_VALUES = ("none",) _TRUE_VALUES = ("true", "1", "t", "yes", "y") _FALSE_VALUES = ("false", "0", "f", "no", "n") @@ -23,9 +23,20 @@ def _str_to_bool(s: str) -> bool: ) +def _str_to_union_bool_str(s: str) -> Union[bool, str]: + s = str(s) + + if s.lower() in _TRUE_VALUES: + return True + elif s.lower() in _FALSE_VALUES: + return False + else: + return s + + def _str_to_opt_str(s: str) -> Optional[str]: s = str(s) - if s.lower() == "none": + if s.lower() in _NONE_VALUES: return None else: return s @@ -33,7 +44,7 @@ def _str_to_opt_str(s: str) -> Optional[str]: def _str_to_opt_int(s: str) -> Optional[int]: s = str(s) - if s.lower() == "none": + if 
s.lower() in _NONE_VALUES: return None else: return int(s) diff --git a/src/conette/utils/collections.py b/src/conette/utils/collections.py index 685cbdad2..209f7d1e2 100644 --- a/src/conette/utils/collections.py +++ b/src/conette/utils/collections.py @@ -1,14 +1,22 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import re + from typing import ( + Any, Callable, Iterable, + Mapping, Optional, + Sequence, TypeVar, + Union, overload, ) +import numpy as np + T = TypeVar("T") U = TypeVar("U") V = TypeVar("V") @@ -25,6 +33,249 @@ def all_eq(it: Iterable[T], eq_fn: Optional[Callable[[T, T], bool]] = None) -> b return all(eq_fn(first, elt) for elt in it) +def all_ne(it: Iterable[T], ne_fn: Optional[Callable[[T, T], bool]] = None) -> bool: + """Returns true if all elements in inputs are differents.""" + it = list(it) + if ne_fn is None: + return all( + it[i] != it[j] for i in range(len(it)) for j in range(i + 1, len(it)) + ) + else: + return all( + ne_fn(it[i], it[j]) for i in range(len(it)) for j in range(i + 1, len(it)) + ) + + +def list_dict_to_dict_list( + lst: Sequence[dict[str, T]], + default_val: U = None, + error_on_missing_key: bool = False, +) -> dict[str, list[Union[T, U]]]: + """Convert a list of dicts to a dict of lists. + + Example 1 + ---------- + >>> lst = [{'a': 1, 'b': 2}, {'a': 4, 'b': 3, 'c': 5}] + >>> output = list_dict_to_dict_list(lst, default_val=0) + {'a': [1, 4], 'b': [2, 3], 'c': [0, 5]} + """ + if len(lst) == 0: + return {} + + if error_on_missing_key: + keys = set(lst[0]) + for dic in lst: + if keys != set(dic.keys()): + raise ValueError( + f"Invalid dict keys for list_dict_to_dict_list. (found {keys} and {dic.keys()})" + ) + + keys = {} + for dic in lst: + keys = keys | dict.fromkeys(dic.keys()) + + out = { + key: [ + lst[i][key] if key in lst[i].keys() else default_val + for i in range(len(lst)) + ] + for key in keys + } + return out + + +def dict_list_to_list_dict(dic: dict[str, list[T]]) -> list[dict[str, T]]: + """Convert dict of lists with same sizes to list of dicts. + + Example 1 + ---------- + ``` + >>> dic = {"a": [1, 2], "b": [3, 4]} + >>> dict_list_to_list_dict(dic) + ... [{"a": 1, "b": 3}, {"a": 2, "b": 4}] + ``` + """ + assert all_eq(map(len, dic.values())) + length = len(next(iter(dic.values()))) + return [{k: v[i] for k, v in dic.items()} for i in range(length)] + + +def flat_dict_of_dict( + nested_dic: Mapping[str, Any], + sep: str = ".", + flat_iterables: bool = False, +) -> dict[str, Any]: + """Flat a nested dictionary. + Example + ---------- + ``` + >>> dic = { + ... "a": 1, + ... "b": { + ... "a": 2, + ... "b": 10, + ... }, + ... } + >>> flat_dict(dic) + ... {"a": 1, "b.a": 2, "b.b": 10} + ``` + """ + output = {} + for k, v in nested_dic.items(): + if isinstance(v, Mapping) and all(isinstance(kv, str) for kv in v.keys()): + v = flat_dict_of_dict(v, sep, flat_iterables) + output |= {f"{k}{sep}{kv}": vv for kv, vv in v.items()} + elif flat_iterables and isinstance(v, Iterable) and not isinstance(v, str): + output |= { + f"{k}{sep}{i}": flat_dict_of_dict(vv, sep, flat_iterables) + for i, vv in enumerate(v) + } + else: + output[k] = v + return output + + +def unflat_dict_of_dict(dic: Mapping[str, Any], sep: str = ".") -> dict[str, Any]: + """Unflat a dictionary. + + Example 1 + ---------- + ``` + >>> dic = { + "a.a": 1, + "b.a": 2, + "b.b": 3, + "c": 4, + } + >>> unflat_dict(dic) + ... 
{"a": {"a": 1}, "b": {"a": 2, "b": 3}, "c": 4} + ``` + """ + output = {} + for k, v in dic.items(): + if sep not in k: + output[k] = v + else: + idx = k.index(sep) + k, kk = k[:idx], k[idx + 1 :] + if k not in output: + output[k] = {} + elif not isinstance(output[k], Mapping): + raise ValueError( + f"Invalid dict argument. (found keys {k} and {k}{sep}{kk})" + ) + + output[k][kk] = v + + output = { + k: (unflat_dict_of_dict(v) if isinstance(v, Mapping) else v) + for k, v in output.items() + } + return output + + +def flat_list_rec( + nested_lst: Union[list, tuple], + returns_shapes: bool = False, +) -> Union[list, tuple]: + """Flat nested list to list of scalars.""" + if not isinstance(nested_lst, (list, tuple)): + output = (nested_lst,), () + else: + flat_lst = [] + shapes = [] + for elt in nested_lst: + subelt, subshapes = flat_list_rec(elt, True) + flat_lst += subelt + shapes.append(subshapes) + + if len(shapes) == 0: + output = [], (0,) + elif all(subshapes == shapes[0] for subshapes in shapes): + output = flat_lst, (len(nested_lst),) + shapes[0] + else: + output = flat_lst, shapes + + if returns_shapes: + return output + else: + return output[0] + + +def unflat_list_rec(flat_lst: list, shapes: Union[list, tuple]) -> list: + """Unflat list to nested list with given shapes.""" + if isinstance(shapes, tuple): + if shapes == (): + return flat_lst[0] + else: + array = np.array(flat_lst, dtype=object) + array = array.reshape(*shapes) + array = array.tolist() + return array + else: + out = [] + idx = 0 + for shape_i in shapes: + num_elements = _prod_rec(shape_i) + unflatten = unflat_list_rec(flat_lst[idx : idx + num_elements], shape_i) + idx += num_elements + out.append(unflatten) + return out + + +def _prod_rec(x: Union[int, float, Iterable]) -> Union[int, float]: + if isinstance(x, (int, float)): + return x + elif isinstance(x, Iterable): + out = 1 + for xi in x: + out *= _prod_rec(xi) + return out + else: + raise TypeError( + f"Invalid argument type {type(x)=}. (expected int, float or iterable of int floats." + ) + + +def is_ascending(x: Iterable, strict: bool = False) -> bool: + x = list(x) + if len(x) <= 1: + return True + + if strict: + return all(xi < x[i + 1] for i, xi in enumerate(x[:-1])) + else: + return all(xi <= x[i + 1] for i, xi in enumerate(x[:-1])) + + +def is_descending(x: Iterable, strict: bool = False) -> bool: + x = list(x) + if len(x) <= 1: + return True + + if strict: + return all(xi > x[i + 1] for i, xi in enumerate(x[:-1])) + else: + return all(xi >= x[i + 1] for i, xi in enumerate(x[:-1])) + + +def sort_dict_with_patterns( + dic: dict[str, Any], + patterns: Iterable[str], +) -> dict[str, Any]: + patterns = list(patterns) + compl_patterns = list(map(re.compile, patterns)) + + def key_fn(key: str) -> int: + for i, pattern in enumerate(compl_patterns): + if re.match(pattern, key): + return i + return len(compl_patterns) + + dic = {k: dic[k] for k in sorted(dic.keys(), key=key_fn)} + return dic + + @overload def unzip(lst: Iterable[tuple[T]]) -> tuple[list[T]]: ... 
diff --git a/src/conette/utils/csv_utils.py b/src/conette/utils/csv_utils.py new file mode 100644 index 000000000..e739e86fd --- /dev/null +++ b/src/conette/utils/csv_utils.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import os.path as osp + +from pathlib import Path +from typing import Any, Iterable, Mapping, Union + +from conette.utils.collections import dict_list_to_list_dict, list_dict_to_dict_list + + +def load_csv_dict( + fpath: Union[str, Path], + has_fieldnames: bool = True, + cast: bool = False, +) -> dict[str, list[Any]]: + data = load_csv_list(fpath, has_fieldnames, cast) + data = list_dict_to_dict_list(data, None, True) + return data + + +def load_csv_list( + fpath: Union[str, Path], + has_fieldnames: bool = True, + cast: bool = False, +) -> list[dict[str, Any]]: + with open(fpath, "r") as file: + if has_fieldnames: + reader = csv.DictReader(file) + data = list(reader) + if len(data) == 0: + return [] + else: + reader = csv.reader(file) + data = list(reader) + if len(data) == 0: + return [] + default_fieldnames = list(map(str, range(len(data[0])))) + data = [dict(zip(default_fieldnames, data_i)) for data_i in data] + + if not cast: + return data + + outs = [] + for data_i in data: + outs_i = {} + for k, vs in data_i.items(): + try: + vs_new = [] + for v in vs: + v = eval(v) + vs_new.append(v) + outs_i[k] = vs_new + except (SyntaxError, NameError): + outs_i[k] = vs + outs.append(outs_i) + return outs + + +def save_csv_dict( + data: Mapping[str, Iterable[Any]], + fpath: Union[str, Path], + overwrite: bool = True, +) -> None: + data = dict(zip(data.keys(), map(list, data.values()))) + data_list = dict_list_to_list_dict(data) # type: ignore + save_csv_list(data_list, fpath, overwrite) + + +def save_csv_list( + data: Iterable[Mapping[str, Any]], + fpath: Union[str, Path], + overwrite: bool = True, +) -> None: + data = list(data) + if len(data) <= 0: + raise ValueError(f"Invalid argument {data=}. (found empty iterable)") + if not overwrite and osp.isfile(fpath): + raise FileExistsError("File already exists and argument overwrite is False.") + + with open(fpath, "w") as file: + fieldnames = list(data[0].keys()) + writer = csv.DictWriter(file, fieldnames) + writer.writeheader() + writer.writerows(data) diff --git a/src/conette/utils/custom_logger.py b/src/conette/utils/custom_logger.py new file mode 100644 index 000000000..3c887b49a --- /dev/null +++ b/src/conette/utils/custom_logger.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import copy +import datetime +import logging +import os.path as osp + +from argparse import Namespace +from typing import Any, Optional, Union + +from omegaconf import DictConfig +from pytorch_lightning.core.saving import save_hparams_to_yaml +from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +class CustomTensorboardLogger(TensorBoardLogger): + """Custom Tensorboard Logger for saving hparams and metrics in tensorboard because we cannot save hparams and metrics several times in SummaryWriter. + + Note : hparams and metrics are saved only when 'save_and_close' is called. 
+ """ + + FNAME_HPARAMS = "hparams.yaml" + FNAME_METRICS = "metrics.yaml" + FNAME_ENDFILE = "endfile.txt" + + def __init__( + self, + save_dir: str, + name: Optional[str] = "default", + version: Union[None, int, str] = None, + log_graph: bool = False, + default_hp_metric: bool = True, + prefix: str = "", + params: Union[dict[str, Any], DictConfig, None] = None, + verbose: bool = False, + log_to_text: bool = False, + **kwargs, + ) -> None: + super().__init__( + save_dir=save_dir, + name=name, + version=version, + log_graph=log_graph, + default_hp_metric=default_hp_metric, + prefix=prefix, + **kwargs, + ) + + params = _convert_dict_like_to_dict(params) + if default_hp_metric: + metrics = {"hp_metric": -1} + else: + metrics = {} + + self._all_hparams = params + self._all_metrics = metrics + self._verbose = verbose + self._log_to_text = log_to_text + + self._closed = False + + def __exit__(self) -> None: + if not self.is_closed(): + self.save_and_close() + + def log_hyperparams( + self, + params: Union[dict[str, Any], Namespace, None] = None, + metrics: Union[dict[str, Any], Namespace, None] = None, + ) -> None: + params = _convert_dict_like_to_dict(params) + metrics = _convert_dict_like_to_dict(metrics) + + none_metrics = {k: v for k, v in metrics.items() if v is None} + if len(none_metrics) > 0: + raise ValueError(f"Found None in metrics. (keys={none_metrics.keys()})") + + self._all_hparams.update(params) + self._all_metrics.update(metrics) + + def finalize(self, status: str) -> None: + # Called at the end of the training (after trainer.fit()) + self.experiment.flush() + + def update_files(self) -> None: + self._all_hparams = {k: _convert_value(v) for k, v in self._all_hparams.items()} + self._all_metrics = {k: _convert_value(v) for k, v in self._all_metrics.items()} + + self._all_hparams = dict(sorted(self._all_hparams.items())) + + fpath_hparams = osp.join(self.log_dir, self.FNAME_HPARAMS) + save_hparams_to_yaml(fpath_hparams, self._all_hparams) + + fpath_metrics = osp.join(self.log_dir, self.FNAME_METRICS) + save_hparams_to_yaml(fpath_metrics, self._all_metrics) + + def save_and_close(self) -> None: + if self.is_closed(): + raise RuntimeError("CustomTensorboardLogger cannot be closed twice.") + + self.update_files() + + if self._log_to_text: + prefix = f"{self.name}_{self.version}" + self.experiment.add_text(f"{prefix}/all_hparams", str(self._all_hparams)) + self.experiment.add_text(f"{prefix}/all_metrics", str(self._all_metrics)) + + for dic in (self._all_hparams, self._all_metrics): + for name, value in dic.items(): + self.experiment.add_text(f"{prefix}/{name}", str(value)) + + super().log_hyperparams(self._all_hparams, self._all_metrics) + self.experiment.flush() + + fpath_endfile = osp.join(self.log_dir, self.FNAME_ENDFILE) + with open(fpath_endfile, "w") as file: + now = datetime.datetime.now() + now = now.strftime("%Y:%m:%d_%H:%M:%S") + file.write(f"Process finished at {now}.\n") + + self._close() + + def _close(self) -> None: + if self._verbose: + pylog.debug( + f"Closing {self.__class__.__name__}... 
({self.is_closed()=}; {self.expt_is_closed()=})" + ) + self.experiment.flush() + super().finalize("test") + self._closed = True + + def is_closed(self) -> bool: + return self._closed or self.expt_is_closed() + + def expt_is_closed(self) -> bool: + return self.experiment.all_writers is None + + @property + def hparams(self) -> dict: + return self._all_hparams + + @hparams.setter + def hparams(self, other: dict) -> None: + self._all_hparams = copy.deepcopy(other) + + @property + def metrics(self) -> dict: + return self._all_metrics + + +def _convert_value(v: Any) -> Any: + if isinstance(v, Tensor): + if v.nelement() == 1: + return v.item() + else: + return v.tolist() + elif isinstance(v, bool): + return str(v) + else: + return v + + +def _convert_dict_like_to_dict(dic: Union[dict, Namespace, DictConfig, None]) -> dict: + if dic is None: + return {} + elif isinstance(dic, Namespace): + return dic.__dict__ + elif isinstance(dic, DictConfig): + return dict(dic) + else: + return dic diff --git a/src/conette/utils/dcase.py b/src/conette/utils/dcase.py new file mode 100644 index 000000000..d5212e11c --- /dev/null +++ b/src/conette/utils/dcase.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import os.path as osp + +from typing import Iterable + + +DCASE_TASK6A_FIELDNAMES = ("file_name", "caption_predicted") +DCASE_TASK6B_TOP_N = 10 +DCASE_TASK6B_FIELDNAMES = ("caption",) + tuple( + f"fname_{i}" for i in range(1, DCASE_TASK6B_TOP_N + 1) +) + + +def export_to_dcase_task6a_csv( + csv_fpath: str, + audio_fnames: Iterable[str], + candidates: Iterable[str], + overwrite: bool = False, +) -> None: + """Export results to DCASE task6a CSV submission file. + + The CSV filename should be __task6a_submission__.csv + + The rules are defined in https://dcase.community/challenge2023/task-automated-audio-captioning#submission + + :param csv_fpath: The path to the new CSV file. + :param audio_fnames: The ordered list of audio filenames. + :param candidates: The ordered captions predicted by your AAC system corresponding to the audio filenames. + :param overwrite: + If the CSV already exists and overwrite is True, the function will replace it. + If the file already exists and overwrite is False, raises a FileExistsError. + It has no effect otherwise. + defaults to False. + """ + + audio_fnames = list(audio_fnames) + candidates = list(candidates) + + if not overwrite and osp.isfile(csv_fpath): + raise FileExistsError( + f"DCASE submission file {csv_fpath=} already exists. Please delete it or use argument overwrite=True." + ) + if len(audio_fnames) != len(candidates): + raise ValueError( + f"Invalid lengths for arguments audio_fnames and candidates. (found {len(audio_fnames)=} != {len(candidates)=})" + ) + + rows = [ + {DCASE_TASK6A_FIELDNAMES[0]: fname, DCASE_TASK6A_FIELDNAMES[1]: cand} + for fname, cand in zip(audio_fnames, candidates) + ] + with open(csv_fpath, "w") as file: + writer = csv.DictWriter(file, fieldnames=DCASE_TASK6A_FIELDNAMES) + writer.writeheader() + writer.writerows(rows) # type: ignore + + +def export_to_dcase_task6b_csv( + csv_fpath: str, + query_captions: Iterable[str], + predicted_fnames: Iterable[Iterable[str]], + overwrite: bool = False, +) -> None: + """Export results to DCASE task6b CSV submission file. + + The CSV filename should be __task6b_submission__output.csv + + The rules are defined in https://dcase.community/challenge2023/task-language-based-audio-retrieval#submission + + :param csv_fpath: The path to the new CSV file. 
+ :param query_captions: The ordered list of queries. + :param predicted_fnames: The ordered list of top-10 filenames corresponding to the queries. + :param overwrite: + If the CSV already exists and overwrite is True, the function will replace it. + If the file already exists and overwrite is False, raises a FileExistsError. + It has no effect otherwise. + defaults to False. + """ + + query_captions = list(query_captions) + predicted_fnames = [list(fnames) for fnames in predicted_fnames] + + if not all(isinstance(query, str) for query in query_captions): + raise TypeError("Invalid argument type query_captions. (expected list[str])") + + if not all( + isinstance(fname, str) for fnames in predicted_fnames for fname in fnames + ): + raise TypeError( + "Invalid argument type predicted_fnames. (expected list[list[str]])" + ) + + if not overwrite and osp.isfile(csv_fpath): + raise FileExistsError( + f"DCASE submission file {csv_fpath=} already exists. Please delete it or use argument overwrite=True." + ) + if len(query_captions) != len(predicted_fnames): + raise ValueError( + f"Invalid lengths for arguments audio_fnames and candidates. (found {len(query_captions)=} != {len(predicted_fnames)=})" + ) + + invalid_lens = [ + query + for query, fnames in zip(query_captions, predicted_fnames) + if len(fnames) != DCASE_TASK6B_TOP_N + ] + if len(invalid_lens) > 0: + raise ValueError( + f"Invalid number of relevant audio filenames. (found {invalid_lens=} but expected only {DCASE_TASK6B_TOP_N} files per query)" + ) + + rows = [ + {DCASE_TASK6B_FIELDNAMES[0]: query} + | dict(zip(DCASE_TASK6B_FIELDNAMES[1:], fnames)) + for query, fnames in zip(query_captions, predicted_fnames) + ] + with open(csv_fpath, "w") as file: + writer = csv.DictWriter(file, fieldnames=DCASE_TASK6B_FIELDNAMES) + writer.writeheader() + writer.writerows(rows) # type: ignore diff --git a/src/conette/utils/disk_cache.py b/src/conette/utils/disk_cache.py new file mode 100644 index 000000000..4fe037fa6 --- /dev/null +++ b/src/conette/utils/disk_cache.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import inspect +import json +import logging +import os +import os.path as osp +import pickle +import shutil +import tempfile +import time + +from pathlib import Path +from typing import Any, Callable, Optional, TypeVar, Union + +import torch + +from conette.utils.csum import csum_any + + +T = TypeVar("T") +pylog = logging.getLogger(__name__) + + +class DiskCache: + CACHE_DNAME = "disk_cache" + + __global_instance: Optional["DiskCache"] = None + + def __init__( + self, + cache_path: Union[Path, str] = tempfile.gettempdir(), + force: bool = False, + filetype: str = "pickle", + verbose: int = 0, + fn: Optional[Callable] = None, + use_pylog: bool = True, + ) -> None: + cache_path = osp.expandvars(cache_path) + if not osp.isdir(cache_path): + raise RuntimeError( + f"Invalid cache directory {cache_path} for {self.__class__.__name__}." + ) + + if filetype not in ("json", "pickle"): + raise ValueError( + f"Invalid argument {filetype=}. 
(expected 'json' or 'pickle')" + ) + + super().__init__() + self._cache_path = cache_path + self._force = force + self._filetype = filetype + self._verbose = verbose + self._fn = fn + self._use_pylog = use_pylog + self._print = pylog.info if use_pylog else print + self._disable = False + + self._in_mode = "r" if filetype == "json" else "rb" + self._out_mode = "w" if filetype == "json" else "wb" + + self._n_hits = 0 + self._n_calls = 0 + + @classmethod + def get(cls, *args, **kwargs) -> "DiskCache": + if cls.__global_instance is None: + cls.__global_instance = DiskCache(*args, **kwargs) + return cls.__global_instance + + def clean(self, fn_or_name: Union[Callable, str]) -> int: + if isinstance(fn_or_name, str): + fn_name = fn_or_name + else: + fn_name = _get_callable_name(fn_or_name) + + target_path = osp.join(self._cache_path, self.CACHE_DNAME, fn_name) + + if not osp.exists(target_path): + pylog.warning(f"Target path does not exists. ({target_path})") + return 0 + + if not osp.isdir(target_path): + pylog.error( + f"Target path exists but it is not a directory. ({target_path})" + ) + return 0 + + n_items_removed = len(os.listdir(target_path)) + shutil.rmtree(target_path) + return n_items_removed + + def reset_stats(self) -> None: + self._n_hits = 0 + self._n_calls = 0 + + def wrap(self, fn: Callable) -> "DiskCache": + return DiskCache( + self._cache_path, + self._force, + self._filetype, + self._verbose, + fn, + self._use_pylog, + ) + + def unwrap(self) -> Optional[Callable]: + fn = self._fn + self._fn = None + return fn + + def __call__(self, *args, **kwargs) -> Any: + if self._fn is None: + raise RuntimeError( + f"Cannot call {self.__class__.__name__} without wrapping a callable object." + ) + return self.cache(self._fn, *args, **kwargs) + + def cache( + self, + fn: Callable[..., T], + *args, + force: Optional[bool] = None, + dc_verbose: Optional[int] = None, + csum_kwargs: Optional[dict[str, Any]] = None, + ignore_fn_csum: bool = False, + allow_compute: bool = True, + filetype: Optional[str] = None, + in_mode: Optional[str] = None, + out_mode: Optional[str] = None, + **kwargs, + ) -> T: + if dc_verbose is None: + dc_verbose = self._verbose + if force is None: + force = self._force + + fpath = self.get_fpath( + fn, + *args, + ignore_fn_csum=ignore_fn_csum, + csum_kwargs=csum_kwargs, + **kwargs, + ) + self._n_calls += 1 + + if not force: + outs, loaded = self.load(fpath, dc_verbose, filetype, in_mode) + else: + outs, loaded = None, False + + if loaded: + self._n_hits += 1 + else: + if not allow_compute: + raise ValueError( + f"Cannot compute outs for {_get_callable_name(fn)} with {allow_compute=}." 
+ ) + + outs, duration = self._compute_outs( + fn, *args, dc_verbose=dc_verbose, **kwargs + ) + if not self._disable: + if dc_verbose >= 2: + self._print(f"Overwrite file {osp.basename(fpath)} with {force=}.") + self.dump(outs, fpath, duration, dc_verbose, filetype, out_mode) + + return outs # type: ignore + + @torch.no_grad() + def get_fpath( + self, + fn: Callable, + *args, + ignore_fn_csum: bool = False, + csum_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> str: + if ignore_fn_csum: + values = (args, kwargs) + else: + values = (fn, args, kwargs) + + default_csum_kwargs: dict[str, Any] = dict( + bytes_mode="adler32", + tensor_mode="sum_order", + iter_order=True, + accumulator=0, + unk_mode="pickle", + ) + if csum_kwargs is not None: + default_csum_kwargs |= csum_kwargs + + csum = csum_any(values, **default_csum_kwargs) + + fn_name = _get_callable_name(fn) + fname = f"{csum}.{self._filetype}" + + fpath = osp.join(self._cache_path, self.CACHE_DNAME, fn_name, fname) + return fpath + + def load( + self, + fpath: str, + dc_verbose: int = 0, + filetype: Optional[str] = None, + in_mode: Optional[str] = None, + ) -> tuple[Any, bool]: + if filetype is None: + filetype = self._filetype + if in_mode is None: + in_mode = self._in_mode + + try: + with open(fpath, in_mode) as file: + if filetype == "json": + outs = json.load(file)["data"] + elif filetype == "pickle": + outs = pickle.load(file)["data"] + else: + raise ValueError(f"Invalid value {filetype=}.") + + if dc_verbose >= 2: + self._print( + f"[HIT_] Outputs loaded from '{osp.basename(fpath)}'. (hits={self.get_n_hits()+1}/{self.get_n_calls()})" + ) + return outs, True + except (FileNotFoundError, json.JSONDecodeError, KeyError, EOFError): + return None, False + + def dump( + self, + outs: Any, + fpath: str, + duration: float = -1.0, + dc_verbose: int = 0, + filetype: Optional[str] = None, + out_mode: Optional[str] = None, + ) -> None: + if osp.isfile(fpath) and osp.getsize(fpath) == 0: + os.remove(fpath) + else: + parent = osp.dirname(fpath) + os.makedirs(parent, exist_ok=True) + + if filetype is None: + filetype = self._filetype + if out_mode is None: + out_mode = self._out_mode + + with open(fpath, out_mode) as file: + data = {"data": outs, "duration": duration} + if filetype == "json": + json.dump(data, file) + elif filetype == "pickle": + pickle.dump(data, file) # type: ignore + else: + raise ValueError(f"Invalid value {filetype=}.") + + if dc_verbose >= 2: + self._print(f"[MISS] Outputs dumped into '{osp.basename(fpath)}'.") + + def disable(self) -> None: + self._disable = True + + def is_disabled(self) -> bool: + return self._disable + + def set_forcing(self, force: bool) -> None: + self._force = force + + def is_forcing(self) -> bool: + return self._force + + def get_cache_path(self) -> str: + return self._cache_path + + def get_n_hits(self) -> int: + return self._n_hits + + def get_n_misses(self) -> int: + return self.get_n_calls() - self.get_n_hits() + + def get_n_calls(self) -> int: + return self._n_calls + + @property + def force(self) -> bool: + return self._force + + @force.setter + def force(self, force_: bool) -> None: + self._force = force_ + + def _compute_outs( + self, + fn: Callable, + *args, + dc_verbose: int = 0, + **kwargs, + ) -> tuple[Any, float]: + fn_name = _get_callable_name(fn) + if dc_verbose >= 1: + self._print(f"[MISS] Computing outs for fn '{fn_name}'...\r") + + start = time.perf_counter() + outs = fn(*args, **kwargs) + duration = time.perf_counter() - start + + if dc_verbose >= 1: + self._print( + 
f'[MISS] Outputs computed in {duration:.2f}s for file "{fn_name}".' + ) + return outs, duration + + +def disk_cache( + fn: Callable[..., T], + *args, + cache_path: str = "~/.cache", + force: bool = False, + dc_verbose: Optional[int] = None, + csum_kwargs: Optional[dict[str, Any]] = None, + ignore_fn_csum: bool = False, + allow_compute: bool = True, + **kwargs, +) -> T: + cache_path = osp.expandvars(cache_path) + cache_path = osp.expanduser(cache_path) + + global_cacher = DiskCache.get() + if global_cacher.get_cache_path() == cache_path: + cacher = global_cacher + else: + cacher = DiskCache(cache_path=cache_path) + + outs = cacher.cache( + fn, + *args, + force=force, + dc_verbose=dc_verbose, + csum_kwargs=csum_kwargs, + ignore_fn_csum=ignore_fn_csum, + allow_compute=allow_compute, + **kwargs, + ) + return outs + + +def _get_callable_name(fn: Callable) -> str: + if isinstance(fn, type) or inspect.isfunction(fn) or inspect.ismethod(fn): + fn_name = fn.__qualname__ + else: + fn_name = fn.__class__.__name__ + return fn_name diff --git a/src/conette/utils/hydra.py b/src/conette/utils/hydra.py new file mode 100644 index 000000000..ffba9d09e --- /dev/null +++ b/src/conette/utils/hydra.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os +import os.path as osp +import pickle + +from functools import cache +from logging import FileHandler +from pathlib import Path +from typing import Any, Optional, Union + +from hydra.core.hydra_config import HydraConfig +from hydra.types import RunMode +from omegaconf import DictConfig, OmegaConf +from omegaconf.errors import ConfigAttributeError + +from conette.utils.collections import flat_dict_of_dict +from conette.utils.yaml_utils import load_yaml + + +pylog = logging.getLogger(__name__) + + +class CustomFileHandler(FileHandler): + """FileHandler which builds intermediate directories. + + Used for export hydra logs to a file contained in a folder that does not exists yet at the start of the program. + """ + + def __init__( + self, + filename: str, + mode: str = "a", + encoding: Optional[str] = None, + delay: bool = True, + ) -> None: + parent_dpath = osp.dirname(filename) + if parent_dpath != "": + try: + os.makedirs(parent_dpath, exist_ok=True) + except PermissionError: + pass + super().__init__(filename, mode, encoding, delay) + + +# Public functions +def setup_resolvers() -> None: + """Prepare resolvers for hydra. + + This function should be called globally or before calling the function wrapped by hydra.main decorator. + """ + resolvers = { + "include_keys": include_keys_fn, + "get_tag": get_tag_fn, + "get_subtag": get_subtag_fn, + "prod": lambda x, y: x * y, + } + for name, resolver in resolvers.items(): + if not OmegaConf.has_resolver(name): + OmegaConf.register_new_resolver(name, resolver) + + +def include_keys_fn(prefix: str, _root_: DictConfig) -> list[str]: + """Special function used by sweeper to determine the job override_dirname by including keys instead of excluding keys. + + To use it, you must register this function as resolver aat the beginning of your program (BEFORE build your config): + ``` + >>> OmegaConf.register_new_resolver(name="include_keys", resolver=include_keys_fn) + ``` + And you can call this function in your config to override dirname. 
+ + As an example, you can use it to include the search space of hydra: + ``` + sweep: + hydra: + job: + config: + override_dirname: + exclude_keys: "${include_keys: hydra.sweeper.search_space}" + ``` + """ + hydra_cfg = _load_hydra_cfg(_root_) + overrides_dic = _get_overrides_from_cfg(hydra_cfg) + included = OmegaConf.select(_root_, key=prefix).keys() + excluded = [value for value in overrides_dic.keys() if value not in included] + return excluded + + +def get_tag_fn(_root_: DictConfig) -> str: + tagk = _root_.tagk + if tagk == "auto": + raise ValueError( + "Cannot load 'multirun.yaml' automatically for tag interpolation." + ) + + join = "-" + tagv = _get_tag_or_subtag(_root_, tagk, "NOTAG", False) + + pretag: str = _root_.pretag + posttag: str = _root_.posttag + + if pretag != "": + if not pretag.endswith(join): + pretag = f"{pretag}{join}" + + if posttag != "": + if not posttag.startswith(join): + posttag = f"{join}{posttag}" + + tagv = f"{pretag}{tagv}{posttag}" + + return tagv + + +def get_subtag_fn(_root_: DictConfig) -> str: + subtagk = _root_.subtagk + + if subtagk == "auto": + hydra_cfg = _load_hydra_cfg(_root_) + overrides = _get_overrides_from_file(hydra_cfg) + subtagk = [k for k, v in overrides.items() if _is_sweep_value(v)] + if _root_.verbose >= 2: + pylog.debug(f"Auto-detect subtags: {subtagk}") + + subtagv = _get_tag_or_subtag(_root_, subtagk, "", True) + return subtagv + + +def get_none(*args, **kwargs) -> None: + """Returns None. + + Can be used for hydra instantiations with: + ``` + _target_: "conette.utils.hydra.get_none" + ``` + """ + return None + + +def get_pickle( + fpath: Union[str, Path], +) -> Any: + """Returns the pickled object from file. + + Can be used for hydra instantiations with: + ``` + _target_: "conette.utils.hydra.get_pickle" + fpath: "/path/to/file" + ``` + + :param fpath: The filepath to the pickled object. + :returns: The pickled object. + """ + if not isinstance(fpath, (str, Path)): + raise TypeError(f"Invalid transform with pickle {fpath=}. (not a str or Path)") + if not osp.isfile(fpath): + raise FileNotFoundError(f"Invalid transform with pickle {fpath=}. (not a file)") + + with open(fpath, "rb") as file: + data = pickle.load(file) + return data + + +@cache +def get_subrun_path() -> str: + hydra_cfg = HydraConfig.get() + return osp.join(hydra_cfg.sweep.dir, hydra_cfg.sweep.subdir) + + +# Private functions +def _load_hydra_cfg(_root_: DictConfig) -> DictConfig: + try: + hydra_cfg = _root_.hydra + return hydra_cfg + except ConfigAttributeError: + pass + + try: + hydra_cfg = HydraConfig.get() + return hydra_cfg # type: ignore + except ValueError as err: + pylog.error( + "Cannot get hydra cfg from root or from global HydraConfig instance." 
+ ) + raise err + + +def _get_tag_or_subtag( + _root_: DictConfig, + keys: Union[str, list[str]], + default: str, + accept_sweep_values: bool, +) -> str: + if isinstance(keys, str): + keys = [keys] + + hydra_cfg = _load_hydra_cfg(_root_) + overrides = _get_overrides_from_cfg(hydra_cfg) + overrides_clean = { + k.replace("/", ".").split(".")[-1]: v for k, v in overrides.items() + } + choices = hydra_cfg.runtime.choices + choices_clean = {k.replace("/", ".").split(".")[-1]: v for k, v in choices.items()} + + options = {} + for key in keys: + if key in overrides: + options[key] = overrides[key] + elif key in overrides_clean: + options[key] = overrides_clean[key] + elif key in choices: + options[key] = choices[key] + elif key in choices_clean: + options[key] = choices_clean[key] + else: + value = OmegaConf.select(_root_, key, default="NOTFOUND") + if value == "NOTFOUND": + dic = OmegaConf.to_container(_root_) + flatten = flat_dict_of_dict(dic) # type: ignore + matches = [k for k in flatten.keys() if k.endswith(key)] + + if len(matches) == 1: + value = OmegaConf.select(_root_, matches[0], default="NOTFOUND") + if value == "NOTFOUND": + pylog.error( + f"INTERNAL ERROR: Cannot find {matches[0]=} in config." + ) + continue + + elif len(matches) == 0: + pylog.warning(f"Cannot find {key=} for tag.") + continue + else: # > 1 + pylog.warning( + f"Found multiple candidates with {key=} for tag. ({matches=})" + ) + continue + + if not isinstance(value, (int, float, str)): + pylog.warning( + f"Ignore {key=} for tag. (expected type in (int, float, str))" + ) + continue + + options[key] = value + + options_clean = {k.replace("/", ".").split(".")[-1]: v for k, v in options.items()} + + if len(options) != len(options_clean): + raise ValueError( + f"Found duplicated option name after dot. (found {tuple(options.keys())} != {tuple(options_clean.keys())})" + ) + + if not accept_sweep_values: + sweep_values = {k: v for k, v in options_clean.items() if _is_sweep_value(v)} + if len(sweep_values) > 0: + raise ValueError( + f"Invalid sweep values for main tag. (sweep keys: {tuple(sweep_values.keys())})" + ) + + tag = "-".join(f"{k}_{v}" for k, v in options_clean.items()) + tag = tag.replace(" ", "") + if tag == "": + tag = default + else: + tag = "-" + tag + + # Clean tag + replaces = { + "=": "_", + ",": "_", + " ": "_", + "[": "", + "]": "", + } + for p, v in replaces.items(): + tag = tag.replace(p, v) + + return tag + + +def _get_overrides_from_cfg(hydra_cfg: DictConfig) -> dict[str, Any]: + overrides = hydra_cfg.overrides.task + overrides_dic = { + kv.split("=")[0].removeprefix("+"): kv.split("=")[1] for kv in overrides + } + + output = {} + for k, v in overrides_dic.items(): + if any(s in v for s in (".", "e", "E")): + try: + v = str(float(v)) + except ValueError: + pass + output[k] = v + + return output + + +def _get_overrides_from_file(hydra_cfg: DictConfig) -> dict[str, Any]: + if hydra_cfg.mode != RunMode.MULTIRUN: + return {} + + multirun_fpath = osp.join(hydra_cfg.sweep.dir, "multirun.yaml") + if not osp.isfile(multirun_fpath): + pylog.error( + f"Cannot find automatically 'multirun.yaml' file in directory '{osp.dirname(multirun_fpath)}'." + ) + return {} + + data = load_yaml(multirun_fpath) + overrides: list[str] = data.get("hydra", {}).get("overrides", {}).get("task", []) + overrides_dic = { + kv.split("=")[0].removeprefix("+"): kv.split("=")[1] for kv in overrides + } + return overrides_dic + + +def _is_sweep_value(v: str) -> bool: + """Returns true if the value is a hydra sweep argument value. 
+ + >>> _is_sweep_value("1,2") + ... True + >>> _is_sweep_value("[1,2]") + ... False + >>> _is_sweep_value("a,b,c") + ... True + >>> _is_sweep_value("something") + ... False + """ + return ( + isinstance(v, str) + and "," in v + and not v.startswith("[") + and not v.endswith("]") + ) + + +def load_overrides(fpath: str) -> dict[str, Any]: + overrides: list[str] = load_yaml(fpath) + overrides_dic = {} + + for override in overrides: + idx = override.find("=") + if idx == -1: + raise RuntimeError(f"Cannot find character '=' in overrides. ({override=})") + + name = override[:idx].removeprefix("++").removeprefix("+") + value = override[idx + 1 :] + overrides_dic[name] = value + + return overrides_dic diff --git a/src/conette/utils/misc.py b/src/conette/utils/misc.py new file mode 100644 index 000000000..c1fa904f4 --- /dev/null +++ b/src/conette/utils/misc.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import datetime +import inspect +import logging +import os +import os.path as osp +import re +import shutil +import subprocess +import zipfile + +from pathlib import Path +from subprocess import CalledProcessError +from typing import ( + Any, + Callable, + Iterable, + Optional, + TypeVar, + Union, +) +from zipfile import ZipFile + +import torch +import tqdm + +from pytorch_lightning.utilities.seed import seed_everything + + +pylog = logging.getLogger(__name__) +T = TypeVar("T") + + +def get_none() -> None: + # Returns None. Can be used for hydra instantiations. + return None + + +def get_datetime(fmt: str = "%Y.%m.%d-%H.%M.%S") -> str: + now = datetime.datetime.now() + return now.strftime(fmt) + + +def reset_seed(seed: Optional[int]) -> Optional[int]: + """Reset the seed of following packages for reproductibility : + - random + - numpy + - torch + - torch.cuda + + Also set deterministic behaviour for cudnn backend. + + :param seed: The seed to set. If None, this function does nothing. + """ + if seed is not None and not isinstance(seed, int): + raise TypeError( + f"Invalid argument type {type(seed)=}. (expected NoneType or int)" + ) + + if seed is None: + return seed + + seed = seed_everything(seed, workers=True) + torch.backends.cudnn.benchmark = False # type: ignore + torch.backends.cudnn.deterministic = True # type: ignore + return seed + + +def save_conda_env(fpath: str, conda_path: str = "conda", verbose: int = 1) -> bool: + try: + cmd = [conda_path, "env", "export", "-f", fpath] + _output = subprocess.check_output(cmd) + return True + except (CalledProcessError, PermissionError, FileNotFoundError) as err: + if verbose >= 0: + pylog.warning(f"Cannot save conda env in {fpath}. ({err=})") + return False + + +def save_micromamba_env( + fpath: str, micromamba_path: str = "micromamba", verbose: int = 1 +) -> bool: + try: + cmd = [micromamba_path, "env", "export"] + output = subprocess.check_output(cmd) + output = output.decode() + with open(fpath, "w") as file: + file.writelines([output]) + return True + except (CalledProcessError, PermissionError, FileNotFoundError) as err: + if verbose >= 0: + pylog.warning(f"Cannot save micromamba env in {fpath}. ({err=})") + return False + + +def get_current_git_hash( + cwd: str = osp.dirname(__file__), + default: T = "UNKNOWN", +) -> Union[str, T]: + """ + Return the current git hash in the current directory. + + :returns: The git hash. If an error occurs, returns 'UNKNOWN'. 
+ """ + try: + git_hash = subprocess.check_output("git describe --always".split(" "), cwd=cwd) + git_hash = git_hash.decode("UTF-8").replace("\n", "") + return git_hash + except (CalledProcessError, PermissionError) as err: + pylog.warning( + f"Cannot get current git hash from {cwd=}. (error message: '{err}')" + ) + return default + + +def get_tags_version(cwd: str = osp.dirname(__file__)) -> str: + """ + {LAST_TAG}-{NB_COMMIT_AFTER_LAST_TAG}-g{LAST_COMMIT_HASH} + Example : v0.1.1-119-g40317c7 + + :returns: The tag version with the git hash. + """ + try: + git_hash = subprocess.check_output("git describe --tags".split(" "), cwd=cwd) + git_hash = git_hash.decode("UTF-8").replace("\n", "") + return git_hash + except (subprocess.CalledProcessError, PermissionError): + return "UNKNOWN" + + +def get_obj_clsname(obj: Any) -> str: + """Returns the full class name of an object.""" + class_ = obj.__class__ + module = class_.__module__ + if module == "builtins": + return class_.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + class_.__qualname__ + + +def save_code_to_zip( + logdir: str, + zip_fname: str = "source_code.zip", + compression: int = zipfile.ZIP_LZMA, + compresslevel: int = 1, + verbose: int = 0, +) -> None: + logdir = osp.expandvars(logdir) + zip_fpath = osp.join(logdir, zip_fname) + code_root_dpath = Path(__file__).parent.parent.parent + + suffixes_dnames = ( + ".ipynb_checkpoints", + "old", + "ign", + "__pycache__", + "/logs", + "/data", + ".egg-info", + ) + + include_fnames = [ + r".*\." + ext + for ext in ("py", "yaml", "rst", "md", "sh", "txt", "cfg", "ini", "in") + ] + exclude_fnames = (r".*(\.ign|\.old|_ign|_old|\.egg-info).*",) + + include_fnames = list(map(re.compile, include_fnames)) + exclude_fnames = list(map(re.compile, exclude_fnames)) + + tgt_fpaths = [] + for root, directories, files in tqdm.tqdm( + os.walk(code_root_dpath), + disable=verbose <= 1, + desc="Searching files to save...", + ): + if any(root.endswith(suffix) for suffix in suffixes_dnames): + directories[:] = [] + continue + tgt_fnames = [ + fname + for fname in files + if any(re.match(p, fname) for p in include_fnames) + and all(not re.match(p, fname) for p in exclude_fnames) + ] + if verbose >= 2 and len(tgt_fnames) > 0: + pylog.debug( + f"{root=} with {len(tgt_fnames)} python files. (ex={tgt_fnames[0]})" + ) + tgt_fpaths += [osp.join(root, fname) for fname in tgt_fnames] + + with ZipFile( + zip_fpath, "w", compression=compression, compresslevel=compresslevel + ) as zfile: + for fpath in tqdm.tqdm( + tgt_fpaths, disable=verbose <= 1, desc=f"Writing {len(tgt_fpaths)} files..." 
+ ): + zfile.write(fpath, arcname=osp.relpath(fpath, code_root_dpath)) + + +def copy_slurm_logs( + fpaths: Iterable[Optional[str]], + subrun_dpath: Optional[str], +) -> None: + if subrun_dpath is None: + return None + if "SLURM_JOB_ID" not in os.environ: + return None + + subrun_dpath = osp.expandvars(subrun_dpath) + fpaths = [fpath for fpath in fpaths if fpath is not None] + + job_id = os.environ["SLURM_JOB_ID"] + replaces = { + "%j": job_id, + "%A": job_id, + } + for pattern, value in replaces.items(): + fpaths = [fpath.replace(pattern, value) for fpath in fpaths] + fpaths = [fpath for fpath in fpaths if osp.isfile(fpath)] + + if len(fpaths) == 0: + return None + + tgt_dpath = osp.join(subrun_dpath, "logs") + os.makedirs(tgt_dpath, exist_ok=True) + + for fpath in fpaths: + fname = osp.basename(fpath) + tgt_fpath = osp.join(tgt_dpath, fname) + shutil.copyfile(fpath, tgt_fpath) + + +def pass_filter( + name: str, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, +) -> bool: + """Returns True if name in include set and not in exclude set.""" + if include is not None and exclude is not None: + return (name in include) and (name not in exclude) + if include is not None: + return name in include + elif exclude is not None: + return name not in exclude + else: + return True + + +def compose(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]: + def compose_impl(x): + for fn in fns: + x = fn(x) + return x + + return compose_impl diff --git a/src/conette/utils/yaml_utils.py b/src/conette/utils/yaml_utils.py new file mode 100644 index 000000000..9ccb4edec --- /dev/null +++ b/src/conette/utils/yaml_utils.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from pathlib import Path +from typing import Any, Union + +import yaml + + +def load_yaml(fpath: Union[str, Path]) -> Any: + with open(fpath, "r") as file: + data = yaml.safe_load(file) + return data + + +def save_yaml(data: Any, fpath: Union[str, Path]) -> None: + with open(fpath, "w") as file: + yaml.dump(data, file) + return data diff --git a/src/conf/__init__.py b/src/conf/__init__.py new file mode 100644 index 000000000..400f5b3cb --- /dev/null +++ b/src/conf/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Note: This file is needed to make the entry points work. diff --git a/src/conf/audio_t/none.yaml b/src/conf/audio_t/none.yaml new file mode 100644 index 000000000..d1a69c926 --- /dev/null +++ b/src/conf/audio_t/none.yaml @@ -0,0 +1,3 @@ +# @package audio_t + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/audio_t/resample_mean_cnn10.yaml b/src/conf/audio_t/resample_mean_cnn10.yaml new file mode 100644 index 000000000..6c9820aa7 --- /dev/null +++ b/src/conf/audio_t/resample_mean_cnn10.yaml @@ -0,0 +1,12 @@ +# @package audio_t + +_target_: "conette.transforms.get.get_resample_mean_cnn10" + +src_sr: ??? +tgt_sr: 32000 +mean_dim: 0 +window_size: 1024 +hop_size: 320 +mel_bins: 64 +device: "auto" +transpose_frame_embs: true diff --git a/src/conf/audio_t/resample_mean_cnn14.yaml b/src/conf/audio_t/resample_mean_cnn14.yaml new file mode 100644 index 000000000..9996dd22c --- /dev/null +++ b/src/conf/audio_t/resample_mean_cnn14.yaml @@ -0,0 +1,15 @@ +# @package audio_t + +_target_: "conette.transforms.get.get_resample_mean_cnn14" + +src_sr: ??? +tgt_sr: 32000 +mean_dim: 0 +window_size: 1024 +hop_size: 320 +mel_bins: 64 +device: "auto" +transpose_frame_embs: true +only_frame_embs: false + +pretrain_path: ??? 
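Reviewer note (not part of the patch): a short sketch of how the new `disk_cache` helper and the DCASE task 6a export added above could be used. `slow_square` and the output CSV path are hypothetical names chosen for illustration, and the cache directory is assumed to already exist (`DiskCache` raises a `RuntimeError` otherwise).

```py
# Illustrative only; assumes the APIs shown in src/conette/utils/disk_cache.py
# and src/conette/utils/dcase.py above.
from conette.utils.dcase import export_to_dcase_task6a_csv
from conette.utils.disk_cache import disk_cache


def slow_square(x: int) -> int:
    # Stand-in for an expensive computation (e.g. feature extraction).
    return x * x


# First call computes and pickles the result under <cache_path>/disk_cache/slow_square/;
# later calls with the same arguments are loaded back from disk instead of recomputed.
out = disk_cache(slow_square, 12, cache_path="~/.cache")
assert out == 144

# Write a DCASE task 6a submission file (one predicted caption per audio file).
export_to_dcase_task6a_csv(
    "candidate_task6a_submission_example.csv",  # hypothetical output path
    audio_fnames=["a.wav", "b.wav"],
    candidates=["a dog barks", "rain falls on a roof"],
    overwrite=True,
)
```

The cached entry is keyed on a checksum of the callable and its arguments, so changing either the function or its inputs produces a new cache file rather than reusing a stale one.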
diff --git a/src/conf/audio_t/resample_mean_cnn14_att.yaml b/src/conf/audio_t/resample_mean_cnn14_att.yaml new file mode 100644 index 000000000..35491e629 --- /dev/null +++ b/src/conf/audio_t/resample_mean_cnn14_att.yaml @@ -0,0 +1,15 @@ +# @package audio_t + +_target_: "conette.transforms.get.get_resample_mean_cnn14_att" + +src_sr: ??? +tgt_sr: 32000 +mean_dim: 0 +window_size: 1024 +hop_size: 320 +mel_bins: 64 +device: "auto" +transpose_frame_embs: true +only_frame_embs: false + +pretrain_path: ??? diff --git a/src/conf/audio_t/resample_mean_convnext.yaml b/src/conf/audio_t/resample_mean_convnext.yaml new file mode 100644 index 000000000..13f450192 --- /dev/null +++ b/src/conf/audio_t/resample_mean_convnext.yaml @@ -0,0 +1,12 @@ +# @package audio_t + +_target_: "conette.transforms.get.get_resample_mean_convnext" + +src_sr: ??? +tgt_sr: 32000 +mean_dim: 0 + +device: "auto" +transpose_frame_embs: true +only_frame_embs: false +pretrain_path: ??? diff --git a/src/conf/audio_t/spec_aug_emb.yaml b/src/conf/audio_t/spec_aug_emb.yaml new file mode 100644 index 000000000..de9a302b2 --- /dev/null +++ b/src/conf/audio_t/spec_aug_emb.yaml @@ -0,0 +1,13 @@ +# @package audio_t + +_target_: "conette.transforms.audio.spec_aug.SpecAugment" + +# Note: input must be (bsize, n_channels, time_steps, freq_steps) or (n_channels, time_steps, freq_steps) + +time_max_width: 2 +time_stripes_num: 2 +freq_max_width: 2 +freq_stripes_num: 2 +time_dim: -2 +freq_dim: -1 +p: 1.0 diff --git a/src/conf/audio_t/spec_aug_ratio_emb.yaml b/src/conf/audio_t/spec_aug_ratio_emb.yaml new file mode 100644 index 000000000..8eba82b9c --- /dev/null +++ b/src/conf/audio_t/spec_aug_ratio_emb.yaml @@ -0,0 +1,13 @@ +# @package audio_t + +_target_: "conette.transforms.audio.spec_aug.SpecAugmentRatio" + +# Note: input must be (bsize, n_channels, time_steps, freq_steps) or (n_channels, time_steps, freq_steps) + +time_ratios: [0.0, 0.1] +time_stripes_num: 2 +freq_ratios: [0.0, 0.1] +freq_stripes_num: 2 +time_dim: -2 +freq_dim: -1 +p: 1.0 diff --git a/src/conf/ckpts/fense.yaml b/src/conf/ckpts/fense.yaml new file mode 100644 index 000000000..4804ea073 --- /dev/null +++ b/src/conf/ckpts/fense.yaml @@ -0,0 +1,15 @@ +# @package ckpts + +- _target_: "conette.callbacks.custom_ckpt.CustomModelCheckpoint" + + dirpath: "${hydra:sweep.dir}/${hydra:sweep.subdir}/checkpoints" + save_last: false + save_top_k: 1 + monitor: "val/fense" + mode: "max" + verbose: ${verbose} + filename: "{epoch:03d}-{step:06d}-mode_${ckpts.0.mode}-{${ckpts.0.monitor}:.4f}" # default: "{epoch}-{step}" + + log_best_score: true + save_after_epoch: null + create_symlink: true diff --git a/src/conf/ckpts/loss.yaml b/src/conf/ckpts/loss.yaml new file mode 100644 index 000000000..680fec97a --- /dev/null +++ b/src/conf/ckpts/loss.yaml @@ -0,0 +1,15 @@ +# @package ckpts + +- _target_: "aac.callbacks.custom_ckpt.CustomModelCheckpoint" + + dirpath: "${hydra:sweep.dir}/${hydra:sweep.subdir}/checkpoints" + save_last: false + save_top_k: 1 + monitor: "val/loss" + mode: "min" + verbose: ${verbose} + filename: "{epoch:03d}-{step:06d}-mode_${ckpts.0.mode}-{${ckpts.0.monitor}:.4f}" # default: "{epoch}-{step}" + + log_best_score: true + save_after_epoch: null + create_symlink: true diff --git a/src/conf/data/audiocaps.yaml b/src/conf/data/audiocaps.yaml new file mode 100644 index 000000000..4f1b15f98 --- /dev/null +++ b/src/conf/data/audiocaps.yaml @@ -0,0 +1,14 @@ +# @package data + +# Common params +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +verbose: 
${verbose} + +# Other params +subsets: null +with_tags: true +download: false +audiocaps_caps_fix_fpath: null diff --git a/src/conf/data/clotho.yaml b/src/conf/data/clotho.yaml new file mode 100644 index 000000000..47b089aae --- /dev/null +++ b/src/conf/data/clotho.yaml @@ -0,0 +1,14 @@ +# @package data + +# Common params +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +verbose: ${verbose} + +# Other params +subsets: null +version: "v2.1" +download: false +clean_archives: false diff --git a/src/conf/data/hdf.yaml b/src/conf/data/hdf.yaml new file mode 100644 index 000000000..bfa06951a --- /dev/null +++ b/src/conf/data/hdf.yaml @@ -0,0 +1,14 @@ +# @package data + +# Common params +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +verbose: ${verbose} + +# Other params +name: ??? +subsets: ??? +hdf_suffix: ??? +include_columns: null diff --git a/src/conf/data/macs.yaml b/src/conf/data/macs.yaml new file mode 100644 index 000000000..4a22c0b7e --- /dev/null +++ b/src/conf/data/macs.yaml @@ -0,0 +1,14 @@ +# @package data + +# Common params +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +verbose: ${verbose} + +# Other params +subsets: null +download: false +tags_to_str: true +clean_archives: false diff --git a/src/conf/data/none.yaml b/src/conf/data/none.yaml new file mode 100644 index 000000000..2a0aea6ed --- /dev/null +++ b/src/conf/data/none.yaml @@ -0,0 +1,8 @@ +# @package data + +# Common params +root: "." +bsize: 1 +n_workers: ${slurm.cpus_per_task} +pin_memory: false +verbose: ${verbose} diff --git a/src/conf/data/wavcaps.yaml b/src/conf/data/wavcaps.yaml new file mode 100644 index 000000000..59e63d58c --- /dev/null +++ b/src/conf/data/wavcaps.yaml @@ -0,0 +1,13 @@ +# @package data + +# Common params +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +verbose: ${verbose} + +# Other params +subsets: [as,bbc,sb] +download: false +hf_cache_dir: null diff --git a/src/conf/dm/hdf.yaml b/src/conf/dm/hdf.yaml new file mode 100644 index 000000000..66700aa38 --- /dev/null +++ b/src/conf/dm/hdf.yaml @@ -0,0 +1,24 @@ +# @package dm + +_target_: "conette.datamodules.hdf.HDFDataModule" + +root: "${path.data}" +bsize: 512 +n_workers: ${slurm.cpus_per_task} +pin_memory: true +train_drop_last: false +verbose: ${verbose} + +train_cols: [audio, audio_shape, captions] +val_cols: [audio, audio_shape, captions] +test_cols: [audio, audio_shape, captions, dataset, subset, fname, index] + +train_hdfs: ??? +val_hdfs: ??? +test_hdfs: ??? 
+predict_hdfs: [] +audio_padding: "batch" +main_hdf_duplicate: null +main_hdf_min: null +main_hdf_balanced: null +n_added_data: null diff --git a/src/conf/evaluator/aac.yaml b/src/conf/evaluator/aac.yaml new file mode 100644 index 000000000..3c55d64b9 --- /dev/null +++ b/src/conf/evaluator/aac.yaml @@ -0,0 +1,14 @@ +# @package evaluator + +_target_: "conette.callbacks.aac_evaluator.AACEvaluator" + +subrun_path: "${hydra:sweep.dir}/${hydra:sweep.subdir}" +cache_path: "${path.cache}" +java_path: "${path.java}" +tmp_path: "${path.tmp}" +verbose: ${verbose} +debug: ${debug} +save_to_csv: ${save} +save_dcase_csv_file: ${save} +metric_device: null +cpus: ${slurm.cpus_per_task} diff --git a/src/conf/evaluator/none.yaml b/src/conf/evaluator/none.yaml new file mode 100644 index 000000000..1916b287d --- /dev/null +++ b/src/conf/evaluator/none.yaml @@ -0,0 +1,3 @@ +# @package evaluator + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/expt/audiocaps_cnext_bl_v6.yaml b/src/conf/expt/audiocaps_cnext_bl_v6.yaml new file mode 100644 index 000000000..e46ce10de --- /dev/null +++ b/src/conf/expt/audiocaps_cnext_bl_v6.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 + +dm: + train_hdfs: + - audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf + val_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + test_hdfs: + - clotho_val_resample_mean_convnext_ident_bl.hdf + - clotho_eval_resample_mean_convnext_ident_bl.hdf + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - audiocaps_test_resample_mean_convnext_ident_bl.hdf + +pl: + proj_name: "lin768" diff --git a/src/conf/expt/audiocaps_cnext_nobl_v6.yaml b/src/conf/expt/audiocaps_cnext_nobl_v6.yaml new file mode 100644 index 000000000..c7fa8dad1 --- /dev/null +++ b/src/conf/expt/audiocaps_cnext_nobl_v6.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 + +dm: + train_hdfs: + - audiocaps_train_v6_resample_mean_convnext_ident_nobl.hdf + val_hdfs: + - audiocaps_val_resample_mean_convnext_ident_nobl.hdf + test_hdfs: + - clotho_val_resample_mean_convnext_ident_nobl.hdf + - clotho_eval_resample_mean_convnext_ident_nobl.hdf + - audiocaps_val_resample_mean_convnext_ident_nobl.hdf + - audiocaps_test_resample_mean_convnext_ident_nobl.hdf + +pl: + proj_name: "lin768" diff --git a/src/conf/expt/audiocaps_cnn14_bl_v6.yaml b/src/conf/expt/audiocaps_cnn14_bl_v6.yaml new file mode 100644 index 000000000..7c44afe37 --- /dev/null +++ b/src/conf/expt/audiocaps_cnn14_bl_v6.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 + +dm: + train_hdfs: + - audiocaps_train_v6_resample_mean_cnn14_ident_bl.hdf + val_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_bl.hdf + test_hdfs: + - clotho_val_resample_mean_cnn14_ident_bl.hdf + - clotho_eval_resample_mean_cnn14_ident_bl.hdf + - audiocaps_val_resample_mean_cnn14_ident_bl.hdf + - audiocaps_test_resample_mean_cnn14_ident_bl.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/audiocaps_cnn14_nobl_v6.yaml b/src/conf/expt/audiocaps_cnn14_nobl_v6.yaml new file mode 100644 index 000000000..e342167f8 --- /dev/null +++ b/src/conf/expt/audiocaps_cnn14_nobl_v6.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 + +dm: + train_hdfs: + - audiocaps_train_v6_resample_mean_cnn14_ident_nobl.hdf + val_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_nobl.hdf + test_hdfs: + - clotho_val_resample_mean_cnn14_ident_nobl.hdf + - clotho_eval_resample_mean_cnn14_ident_nobl.hdf + - 
audiocaps_val_resample_mean_cnn14_ident_nobl.hdf + - audiocaps_test_resample_mean_cnn14_ident_nobl.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/audiocaps_cnn14_pann_v6.yaml b/src/conf/expt/audiocaps_cnn14_pann_v6.yaml new file mode 100644 index 000000000..0090c68c0 --- /dev/null +++ b/src/conf/expt/audiocaps_cnn14_pann_v6.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 + +dm: + train_hdfs: + - audiocaps_train_v6_resample_mean_cnn14_ident_pann.hdf + val_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_pann.hdf + test_hdfs: + - clotho_val_resample_mean_cnn14_ident_pann.hdf + - clotho_eval_resample_mean_cnn14_ident_pann.hdf + - audiocaps_val_resample_mean_cnn14_ident_pann.hdf + - audiocaps_test_resample_mean_cnn14_ident_pann.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/camw_cnext_bl_for_a.yaml b/src/conf/expt/camw_cnext_bl_for_a.yaml new file mode 100644 index 000000000..e32cc0b75 --- /dev/null +++ b/src/conf/expt/camw_cnext_bl_for_a.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + - hp_audiocaps_v2 # note: could be changed if target dset change + +dm: + train_hdfs: + - clotho_dev_resample_mean_convnext_ident_bl.hdf + - audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf + - macs_full_resample_mean_convnext_ident_bl.hdf + - wavcaps_as_noac_resample_mean_convnext_ident_bl.hdf + - wavcaps_bbc_resample_mean_convnext_ident_bl.hdf + - wavcaps_sb_resample_mean_convnext_ident_bl.hdf + - wavcaps_fsd_nocl_resample_mean_convnext_ident_bl.hdf + val_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf # note: could be changed if target dset change + test_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - audiocaps_test_resample_mean_convnext_ident_bl.hdf + - clotho_val_resample_mean_convnext_ident_bl.hdf + - clotho_eval_resample_mean_convnext_ident_bl.hdf + predict_hdfs: + - clotho_test_resample_mean_convnext_ident_bl.hdf + - clotho_analysis_resample_mean_convnext_ident_bl.hdf + + main_hdf_min: audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf # note: could be changed if target dset change + +pl: + proj_name: "lin768" + +trainer: + reload_dataloaders_every_n_epochs: 1 # /!\ to be used with main_hdf_min to sample different other data across epochs diff --git a/src/conf/expt/camw_cnext_bl_for_ac.yaml b/src/conf/expt/camw_cnext_bl_for_ac.yaml new file mode 100644 index 000000000..a811cbef0 --- /dev/null +++ b/src/conf/expt/camw_cnext_bl_for_ac.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +defaults: + - hp_audiocaps_clotho_v2 # note: could be changed if target dset change + +dm: + train_hdfs: + - clotho_dev_resample_mean_convnext_ident_bl.hdf + - audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf + - macs_full_resample_mean_convnext_ident_bl.hdf + - wavcaps_as_noac_resample_mean_convnext_ident_bl.hdf + - wavcaps_bbc_resample_mean_convnext_ident_bl.hdf + - wavcaps_sb_resample_mean_convnext_ident_bl.hdf + - wavcaps_fsd_nocl_resample_mean_convnext_ident_bl.hdf + val_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - clotho_val_resample_mean_convnext_ident_bl.hdf + test_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - audiocaps_test_resample_mean_convnext_ident_bl.hdf + - clotho_val_resample_mean_convnext_ident_bl.hdf + - clotho_eval_resample_mean_convnext_ident_bl.hdf + predict_hdfs: + - clotho_test_resample_mean_convnext_ident_bl.hdf + - clotho_analysis_resample_mean_convnext_ident_bl.hdf + + main_hdf_balanced: [audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf, 
clotho_dev_resample_mean_convnext_ident_bl.hdf] + +pl: + proj_name: "lin768" + +trainer: + reload_dataloaders_every_n_epochs: 1 # /!\ to be used with main_hdf_min to sample different other data across epochs diff --git a/src/conf/expt/camw_cnext_bl_for_c.yaml b/src/conf/expt/camw_cnext_bl_for_c.yaml new file mode 100644 index 000000000..573e264ad --- /dev/null +++ b/src/conf/expt/camw_cnext_bl_for_c.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 # note: could be changed if target dset change + +dm: + train_hdfs: + - clotho_dev_resample_mean_convnext_ident_bl.hdf + - audiocaps_train_v6_resample_mean_convnext_ident_bl.hdf + - macs_full_resample_mean_convnext_ident_bl.hdf + - wavcaps_as_noac_resample_mean_convnext_ident_bl.hdf + - wavcaps_bbc_resample_mean_convnext_ident_bl.hdf + - wavcaps_sb_resample_mean_convnext_ident_bl.hdf + - wavcaps_fsd_nocl_resample_mean_convnext_ident_bl.hdf + val_hdfs: + - clotho_val_resample_mean_convnext_ident_bl.hdf + test_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - audiocaps_test_resample_mean_convnext_ident_bl.hdf + - clotho_val_resample_mean_convnext_ident_bl.hdf + - clotho_eval_resample_mean_convnext_ident_bl.hdf + predict_hdfs: + - clotho_test_resample_mean_convnext_ident_bl.hdf + - clotho_analysis_resample_mean_convnext_ident_bl.hdf + + main_hdf_min: clotho_dev_resample_mean_convnext_ident_bl.hdf + +pl: + proj_name: "lin768" + +trainer: + reload_dataloaders_every_n_epochs: 1 # /!\ to be used with main_hdf_min to sample different other data across epochs diff --git a/src/conf/expt/clotho_cnext_bl.yaml b/src/conf/expt/clotho_cnext_bl.yaml new file mode 100644 index 000000000..cf1d4c3d2 --- /dev/null +++ b/src/conf/expt/clotho_cnext_bl.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_convnext_ident_bl.hdf + val_hdfs: + - clotho_val_resample_mean_convnext_ident_bl.hdf + test_hdfs: + - audiocaps_val_resample_mean_convnext_ident_bl.hdf + - audiocaps_test_resample_mean_convnext_ident_bl.hdf + - clotho_val_resample_mean_convnext_ident_bl.hdf + - clotho_eval_resample_mean_convnext_ident_bl.hdf + predict_hdfs: + - clotho_test_resample_mean_convnext_ident_bl.hdf + - clotho_analysis_resample_mean_convnext_ident_bl.hdf + +pl: + proj_name: "lin768" diff --git a/src/conf/expt/clotho_cnext_nobl.yaml b/src/conf/expt/clotho_cnext_nobl.yaml new file mode 100644 index 000000000..d87f850e5 --- /dev/null +++ b/src/conf/expt/clotho_cnext_nobl.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_convnext_ident_nobl.hdf + val_hdfs: + - clotho_val_resample_mean_convnext_ident_nobl.hdf + test_hdfs: + - audiocaps_val_resample_mean_convnext_ident_nobl.hdf + - audiocaps_test_resample_mean_convnext_ident_nobl.hdf + - clotho_val_resample_mean_convnext_ident_nobl.hdf + - clotho_eval_resample_mean_convnext_ident_nobl.hdf + predict_hdfs: + - clotho_test_resample_mean_convnext_ident_nobl.hdf + - clotho_analysis_resample_mean_convnext_ident_nobl.hdf + +pl: + proj_name: "lin768" diff --git a/src/conf/expt/clotho_cnn10.yaml b/src/conf/expt/clotho_cnn10.yaml new file mode 100644 index 000000000..90fb04130 --- /dev/null +++ b/src/conf/expt/clotho_cnn10.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_cnn10_ident.hdf + val_hdfs: + - clotho_val_resample_mean_cnn10_ident.hdf + test_hdfs: + - audiocaps_val_resample_mean_cnn10_ident.hdf + - 
audiocaps_test_resample_mean_cnn10_ident.hdf + - clotho_val_resample_mean_cnn10_ident.hdf + - clotho_eval_resample_mean_cnn10_ident.hdf + predict_hdfs: + - clotho_test_resample_mean_cnn10_ident.hdf + - clotho_analysis_resample_mean_cnn10_ident.hdf + +pl: + proj_name: "lin512" diff --git a/src/conf/expt/clotho_cnn14_att.yaml b/src/conf/expt/clotho_cnn14_att.yaml new file mode 100644 index 000000000..a264bb315 --- /dev/null +++ b/src/conf/expt/clotho_cnn14_att.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_cnn14_att_ident.hdf + val_hdfs: + - clotho_val_resample_mean_cnn14_att_ident.hdf + test_hdfs: + - audiocaps_val_resample_mean_cnn14_att_ident.hdf + - audiocaps_test_resample_mean_cnn14_att_ident.hdf + - clotho_val_resample_mean_cnn14_att_ident.hdf + - clotho_eval_resample_mean_cnn14_att_ident.hdf + predict_hdfs: + - clotho_test_resample_mean_cnn14_att_ident.hdf + - clotho_analysis_resample_mean_cnn14_att_ident.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/clotho_cnn14_bl.yaml b/src/conf/expt/clotho_cnn14_bl.yaml new file mode 100644 index 000000000..6c1ec407e --- /dev/null +++ b/src/conf/expt/clotho_cnn14_bl.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_cnn14_ident_bl.hdf + val_hdfs: + - clotho_val_resample_mean_cnn14_ident_bl.hdf + test_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_bl.hdf + - audiocaps_test_resample_mean_cnn14_ident_bl.hdf + - clotho_val_resample_mean_cnn14_ident_bl.hdf + - clotho_eval_resample_mean_cnn14_ident_bl.hdf + predict_hdfs: + - clotho_test_resample_mean_cnn14_ident_bl.hdf + - clotho_analysis_resample_mean_cnn14_ident_bl.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/clotho_cnn14_nobl.yaml b/src/conf/expt/clotho_cnn14_nobl.yaml new file mode 100644 index 000000000..81c67a1f4 --- /dev/null +++ b/src/conf/expt/clotho_cnn14_nobl.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_cnn14_ident_nobl.hdf + val_hdfs: + - clotho_val_resample_mean_cnn14_ident_nobl.hdf + test_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_nobl.hdf + - audiocaps_test_resample_mean_cnn14_ident_nobl.hdf + - clotho_val_resample_mean_cnn14_ident_nobl.hdf + - clotho_eval_resample_mean_cnn14_ident_nobl.hdf + predict_hdfs: + - clotho_test_resample_mean_cnn14_ident_nobl.hdf + - clotho_analysis_resample_mean_cnn14_ident_nobl.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/clotho_cnn14_pann.yaml b/src/conf/expt/clotho_cnn14_pann.yaml new file mode 100644 index 000000000..a01ddc7db --- /dev/null +++ b/src/conf/expt/clotho_cnn14_pann.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - hp_clotho_v2 + +dm: + train_hdfs: + - clotho_dev_resample_mean_cnn14_ident_pann.hdf + val_hdfs: + - clotho_val_resample_mean_cnn14_ident_pann.hdf + test_hdfs: + - audiocaps_val_resample_mean_cnn14_ident_pann.hdf + - audiocaps_test_resample_mean_cnn14_ident_pann.hdf + - clotho_val_resample_mean_cnn14_ident_pann.hdf + - clotho_eval_resample_mean_cnn14_ident_pann.hdf + predict_hdfs: + - clotho_test_resample_mean_cnn14_ident_pann.hdf + - clotho_analysis_resample_mean_cnn14_ident_pann.hdf + +pl: + proj_name: "lin2048" diff --git a/src/conf/expt/hp_audiocaps_clotho_v2.yaml b/src/conf/expt/hp_audiocaps_clotho_v2.yaml new file mode 100644 index 000000000..5e75337f8 --- /dev/null +++ b/src/conf/expt/hp_audiocaps_clotho_v2.yaml @@ -0,0 +1,28 @@ +# @package 
_global_ + +defaults: + - override /ckpts: fense + - override /audio_t@audio_t.train: spec_aug_ratio_emb + +pl: + label_smoothing: 0.2 + beam_size: 3 + max_pred_size: 20 + gen_val_cands: "generate" + +slurm: + time: "10:00:00" + +trainer: + gradient_clip_val: 1 + max_epochs: 400 + +audio_t: + train: + time_ratios: [0.0,0.1] + time_stripes_num: 2 + freq_ratios: [0.0,0.1] + freq_stripes_num: 2 + time_dim: -2 + freq_dim: -1 + p: 1.0 diff --git a/src/conf/expt/hp_audiocaps_v2.yaml b/src/conf/expt/hp_audiocaps_v2.yaml new file mode 100644 index 000000000..6c1ea98a2 --- /dev/null +++ b/src/conf/expt/hp_audiocaps_v2.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +defaults: + - override /ckpts: fense + - override /audio_t@audio_t.train: spec_aug_ratio_emb + +pl: + label_smoothing: 0.1 + beam_size: 2 + max_pred_size: 30 + gen_val_cands: "generate" + +slurm: + time: "4:00:00" + +trainer: + gradient_clip_val: 10 + max_epochs: 100 + +audio_t: + train: + time_ratios: [0.0,0.1] + time_stripes_num: 2 + freq_ratios: [0.0,0.1] + freq_stripes_num: 2 + time_dim: -2 + freq_dim: -1 + p: 1.0 diff --git a/src/conf/expt/hp_clotho_v1.yaml b/src/conf/expt/hp_clotho_v1.yaml new file mode 100644 index 000000000..9506e45b3 --- /dev/null +++ b/src/conf/expt/hp_clotho_v1.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +defaults: + - override /ckpts: fense + - override /audio_t@audio_t.train: spec_aug_emb + +pl: + label_smoothing: 0.2 + beam_size: 3 + max_pred_size: 20 + gen_val_cands: "generate" + +slurm: + time: "6:00:00" + +trainer: + gradient_clip_val: 1 + max_epochs: 400 + +audio_t: + train: + time_max_width: 4 + time_stripes_num: 6 + freq_max_width: 2 + freq_stripes_num: 2 + time_dim: -2 + freq_dim: -1 + p: 1.0 diff --git a/src/conf/expt/hp_clotho_v2.yaml b/src/conf/expt/hp_clotho_v2.yaml new file mode 100644 index 000000000..369461d01 --- /dev/null +++ b/src/conf/expt/hp_clotho_v2.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +defaults: + - override /ckpts: fense + - override /audio_t@audio_t.train: spec_aug_ratio_emb + +pl: + label_smoothing: 0.2 + beam_size: 3 + max_pred_size: 20 + gen_val_cands: "generate" + +slurm: + time: "6:00:00" + +trainer: + gradient_clip_val: 1 + max_epochs: 400 + +audio_t: + train: + time_ratios: [0.0,0.1] + time_stripes_num: 2 + freq_ratios: [0.0,0.1] + freq_stripes_num: 2 + time_dim: -2 + freq_dim: -1 + p: 1.0 diff --git a/src/conf/expt/none.yaml b/src/conf/expt/none.yaml new file mode 100644 index 000000000..03bfe3dba --- /dev/null +++ b/src/conf/expt/none.yaml @@ -0,0 +1 @@ +# @package _global_ diff --git a/src/conf/expt/task_ds_src_camw.yaml b/src/conf/expt/task_ds_src_camw.yaml new file mode 100644 index 000000000..3fea073ae --- /dev/null +++ b/src/conf/expt/task_ds_src_camw.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +defaults: + - override /pl: conette + +pl: + task_names: [clotho, audiocaps, macs, wavcaps_audioset_sl, wavcaps_bbc_sound_effects, wavcaps_freesound, wavcaps_soundbible] + task_mode: ds_src + gen_val_cands: generate + gen_test_cands: generate + +dm: + train_cols: [audio, audio_shape, captions, dataset, source] + val_cols: [audio, audio_shape, captions, dataset, source] + test_cols: [audio, audio_shape, captions, dataset, subset, fname, index, source] diff --git a/src/conf/hydra/custom.yaml b/src/conf/hydra/custom.yaml new file mode 100644 index 000000000..245fbbc6c --- /dev/null +++ b/src/conf/hydra/custom.yaml @@ -0,0 +1,24 @@ +# @package hydra + +defaults: + - override hydra_logging: colorlog + - override job_logging: custom # redirect log file to output_subdir + - 
override launcher: basic + - override sweeper: basic + - _self_ + +job: + # note: add default num for single-runs + num: 0 + +# Set hydra working dir for single runs +run: + dir: "${path.log_root}/${hydra.job.name}-${datetime}-${tagv}" + +# Set hydra working dir for multiruns +sweep: + dir: "${path.log_root}/${hydra.job.name}-${datetime}-${tagv}" + subdir: "${hydra.job.num}-${subtagv}" + +# Set args save in board dir +output_subdir: "${hydra.sweep.subdir}/hydra" diff --git a/src/conf/hydra/job_logging/custom.yaml b/src/conf/hydra/job_logging/custom.yaml new file mode 100644 index 000000000..59acfdb30 --- /dev/null +++ b/src/conf/hydra/job_logging/custom.yaml @@ -0,0 +1,34 @@ +# @package hydra.job_logging + +# Note: redirect log file to output_subdir + +version: 1 +formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + colorlog: + '()': 'colorlog.ColoredFormatter' + format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s' + log_colors: + DEBUG: purple + INFO: green + WARNING: yellow + ERROR: red + CRITICAL: red + +handlers: + console: + class: logging.StreamHandler + formatter: colorlog + stream: ext://sys.stdout + file: + class: conette.utils.hydra.CustomFileHandler + formatter: colorlog + filename: "${hydra.sweep.dir}/${hydra.sweep.subdir}/logs/outputs.log" + +root: + level: INFO + handlers: + - console + - file +disable_existing_loggers: false diff --git a/src/conf/launcher/local.yaml b/src/conf/launcher/local.yaml new file mode 100644 index 000000000..afe50834e --- /dev/null +++ b/src/conf/launcher/local.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +defaults: + - /path: local + - _self_ + + +slurm: + # --- Config + + # Account for sbatch + account: null + # Nodes constraints (also -C) + constraint: null + # Number of CPUs (also -c) + cpus_per_task: null + # Cores selection (also -m) + distribution: null + # Path to the stderr file (also -e) + error: "${path.log_root}/slurm/%j-${slurm.job_name}.out" + # GPU resources + gres: "gpu:${slurm.gpus}" + # GPU resources flags + gres_flags: null + # Job name (also -J) + job_name: "${tagv}-${subtagv}" + # Global RAM memory to use. Memory format : number[K|M|G|T]. If "0", no memory limit, use all of memory in node. + mem: null + # Memory per CPU to use. + mem_per_cpu: null + # Number of nodes (also -N) + nodes: 1 + # Number of tasks (also -n) + ntasks_per_node: 1 + # Path to the stdout file (also -o) + output: "${path.log_root}/slurm/%j-${slurm.job_name}.out" + # Partition (also -p) + partition: null + # Quality Of Service queue (also -q) + qos: null + # Time format : days-hours:minutes:seconds. If "0", no time limit. 
Example for 3 days : 3-00:00:00 (also -t) + time: 0 + + # --- Other + + # Number of GPUs + gpus: 1 + # Module commands executed before srun + module_cmds: "" + # Sbatch command used to execute the sbatch script + sbatch: "bash" + # Srun prefix used to run python command + srun: "" + # Test sbatch file without launching the job + test_only: false diff --git a/src/conf/logger/custom_tb.yaml b/src/conf/logger/custom_tb.yaml new file mode 100644 index 000000000..fd24e39c2 --- /dev/null +++ b/src/conf/logger/custom_tb.yaml @@ -0,0 +1,8 @@ +# @package logger + +_target_: "conette.utils.custom_logger.CustomTensorboardLogger" + +save_dir: "${hydra:sweep.dir}/${hydra:sweep.subdir}" +name: null +version: "" +default_hp_metric: true diff --git a/src/conf/path/local.yaml b/src/conf/path/local.yaml new file mode 100644 index 000000000..f9b009216 --- /dev/null +++ b/src/conf/path/local.yaml @@ -0,0 +1,26 @@ +# @package path + +# Local default paths + +# Cache directory +cache: "${oc.env:HOME}/.cache" +# Default data directory +data: "data" +# To redirect gensim data and models +gensim: null +# For audiocaps download +ffmpeg: "ffmpeg" +# For SPICE & METEOR metrics +java: "java" +# Parent log directory (cannot contains environment variables) +log_root: "logs" +# Path to micromamba +micromamba: "micromamba" +# Global tmp directory +tmp: "/tmp" +# To redirect torch models (default: ~/.cache/torch/hub) +torch_hub: null +# For audiocaps download +ytdl: "yt-dlp" +# For FSD50K download +zip: "zip" diff --git a/src/conf/pl/baseline.yaml b/src/conf/pl/baseline.yaml new file mode 100644 index 000000000..e91f8666b --- /dev/null +++ b/src/conf/pl/baseline.yaml @@ -0,0 +1,34 @@ +# @package pl + +_target_: "conette.pl_modules.baseline.BaselinePLM" + +# Model params +label_smoothing: 0.1 +gen_val_cands: "none" +mixup_alpha: 0.4 +# Encoder params +proj_name: "lin768" +# Decoder params +nhead: 8 +d_model: 256 +num_decoder_layers: 6 +decoder_dropout_p: 0.2 +dim_feedforward: 2048 +acti_name: "gelu" +# Generate params +min_pred_size: 3 +max_pred_size: 30 +beam_size: 2 +# Optimizer params +optim_name: "AdamW" +lr: 0.0005 +weight_decay: 2.0 +betas: [0.9, 0.999] +eps: 1e-8 +use_custom_wd: true +# Scheduler params +sched_name: "cos_decay" +sched_n_steps: ${trainer.max_epochs} +sched_interval: "epoch" +# Other params +verbose: ${verbose} diff --git a/src/conf/pl/conette.yaml b/src/conf/pl/conette.yaml new file mode 100644 index 000000000..0506f2f1b --- /dev/null +++ b/src/conf/pl/conette.yaml @@ -0,0 +1,37 @@ +# @package pl + +_target_: "conette.pl_modules.conette.CoNeTTEPLM" + +task_mode: "ds_src" +task_names: [clotho, audiocaps, macs, wavcaps_audioset_sl, wavcaps_bbc_sound_effects, wavcaps_freesound, wavcaps_soundbible] +gen_test_cands: "generate" +# Model params +label_smoothing: 0.1 +gen_val_cands: "none" +mixup_alpha: 0.4 +# Encoder params +proj_name: "lin768" +# Decoder params +nhead: 8 +d_model: 256 +num_decoder_layers: 6 +decoder_dropout_p: 0.2 +dim_feedforward: 2048 +acti_name: "gelu" +# Generate params +min_pred_size: 3 +max_pred_size: 30 +beam_size: 2 +# Optimizer params +optim_name: "AdamW" +lr: 0.0005 +weight_decay: 2.0 +betas: [0.9, 0.999] +eps: 1e-8 +use_custom_wd: true +# Scheduler params +sched_name: "cos_decay" +sched_n_steps: ${trainer.max_epochs} +sched_interval: "epoch" +# Other params +verbose: ${verbose} diff --git a/src/conf/prepare.yaml b/src/conf/prepare.yaml new file mode 100644 index 000000000..cc6fad97d --- /dev/null +++ b/src/conf/prepare.yaml @@ -0,0 +1,82 @@ +# @package _global_ + +# Group params 
+defaults: + - audio_t: resample_mean_convnext + - data: clotho + - hydra: custom + - launcher: local + - logger: custom_tb + - text_t: ident + - override hydra/job_logging: custom # redirect log file to output_subdir + - _self_ + +# --- Common params + +# bool +debug: false +# str | None +git_hash: null +# str +posttag: "" +# str +pretag: "" +# int +seed: 1234 +# str | None +sharing_strategy: null +# str | list[str] | "auto" +subtagk: "auto" +# str | list[str] +tagk: [] +# int +verbose: 1 + +# --- Auto params + +# str +datetime: ${now:%Y.%m.%d-%H.%M.%S} +# str +tagv: "${get_tag:}" +# str +subtagv: "${get_subtag:}" + +# --- Other params + +# bool +default: true +nltk: ${default} +spacy: ${default} +pann: false +audioset_indices: ${default} +ptb_tokenizer: ${default} +meteor: ${default} +spice: ${default} +fense: ${default} +cnext: ${default} + +# bool +pack_to_hdf: ${default} +# bool +overwrite_hdf: false +# Optional[str] +post_hdf_name: null +# bool +csum_in_hdf_name: false + +# Override data.download option +data: + # bool + download: ${default} + +datafilter: + # float + min_audio_size: 0.0 + # float | "inf" + max_audio_size: inf + # Optional[int] + imin: null + # Optional[int] + imax: null + # Optional[int] + sr: null diff --git a/src/conf/retrieve.yaml b/src/conf/retrieve.yaml new file mode 100644 index 000000000..287b6a8a9 --- /dev/null +++ b/src/conf/retrieve.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +defaults: + - hydra: custom + - launcher: ??? + - _self_ + +# --- Common params + +# bool +debug: false +# str | None +git_hash: null +# str +posttag: "" +# str +pretag: "" +# int +seed: 1234 +# str | None +sharing_strategy: null +# str | list[str] | "auto" +subtagk: "auto" +# str | list[str] +tagk: [] +# int +verbose: 1 + +# --- Auto params + +# str +datetime: ${now:%Y.%m.%d-%H.%M.%S} +# str +tagv: ${get_tag:} +# str +subtagv: ${get_subtag:} + +# --- Other params + +# str | list[str] +resume: ??? +# str | list[str] +hdf_fnames: ??? 
+# int +bsize: 512 +# int +n_workers: ${slurm.cpus_per_task} +# str +device: "auto" +# str | list[str] +t2a_modes: + - "loss" + - "scaled_loss" +# str | list[str] +a2t_modes: + - "loss" + - "scaled_loss" diff --git a/src/conf/text_t/ident.yaml b/src/conf/text_t/ident.yaml new file mode 100644 index 000000000..53895f62f --- /dev/null +++ b/src/conf/text_t/ident.yaml @@ -0,0 +1,3 @@ +# @package text_t + +_target_: "torch.nn.Identity" diff --git a/src/conf/text_t/none.yaml b/src/conf/text_t/none.yaml new file mode 100644 index 000000000..5c4df758f --- /dev/null +++ b/src/conf/text_t/none.yaml @@ -0,0 +1,3 @@ +# @package text_t + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/tok/spacy.yaml b/src/conf/tok/spacy.yaml new file mode 100644 index 000000000..a2bead25f --- /dev/null +++ b/src/conf/tok/spacy.yaml @@ -0,0 +1,12 @@ +# @package tok + +_target_: "conette.tokenization.aac_tokenizer.AACTokenizer" + +level: "word" +lowercase: true +punctuation_mode: "remove" +normalize: true + +backend: "spacy" + +model_name: "en_core_web_sm" diff --git a/src/conf/tok/test.yaml b/src/conf/tok/test.yaml new file mode 100644 index 000000000..093cf5e72 --- /dev/null +++ b/src/conf/tok/test.yaml @@ -0,0 +1,14 @@ +# @package tok + +_target_: "conette.tokenization.aac_tokenizer.AACTokenizer" + +level: "word" +lowercase: false +punctuation_mode: "remove" # note: PTB tokenizer already delete punctuations tokens +normalize: true + +backend: "ptb" + +cache_path: "${path.cache}" +java_path: "${path.java}" +tmp_path: "${path.tmp}" diff --git a/src/conf/train.yaml b/src/conf/train.yaml new file mode 100644 index 000000000..794b75cb2 --- /dev/null +++ b/src/conf/train.yaml @@ -0,0 +1,93 @@ +# @package _global_ + +defaults: + - audio_t@audio_t.train: spec_aug_ratio_emb + - audio_t@audio_t.val: none + - audio_t@audio_t.test: none + - ckpts: loss + - dm: hdf + - evaluator: aac + - hydra: custom + - launcher: local + - logger: custom_tb + - pl: conette + - tok@train_tok: spacy + - tok@test_tok: test + - trainer: fit_test + - _self_ + # note: expt must be the last in defaults list + - expt: clotho_cnext_bl + +# --- Common params + +# bool +debug: false +# str | None +git_hash: null +# str +posttag: "" +# str +pretag: "" +# int +seed: 1234 +# str | None +sharing_strategy: null +# str | list[str] | "auto" +subtagk: "auto" +# str | list[str] +tagk: [] +# int +verbose: 1 + +# --- Auto params + +# str +datetime: ${now:%Y.%m.%d-%H.%M.%S} +# str +tagv: ${get_tag:} +# str +subtagv: ${get_subtag:} + +# --- Other params + +# str | None +resume: null +# str | None +resume_2: null +# bool +strict_resume: true +# bool +resume_before_setup: false +# bool +save: true +# bool +val_on_start: true +# bool +test_on_start: true +# str | None +out_crit: null +# float +out_default: -1.0 +# str | list[str] +val_metrics_keys: [] +# str | list[str] +ign_weights: [] +# bool +enable_dspeed: true + +testing: + # list[str] + # Can contains: "best", "last", "none", "swa" + run: [best] + + # Note: this param is ignored if "swa" is not in testing.run + swa: + _target_: "pytorch_lightning.callbacks.StochasticWeightAveraging" + # int | float + swa_epoch_start: 0.8 + # float | List[float] | None + swa_lrs: null + # int + annealing_epochs: 10 + # str: "cos", "linear" + annealing_strategy: "cos" diff --git a/src/conf/trainer/dev.yaml b/src/conf/trainer/dev.yaml new file mode 100644 index 000000000..908abceee --- /dev/null +++ b/src/conf/trainer/dev.yaml @@ -0,0 +1,10 @@ +# @package trainer + +defaults: + - fit_test + +detect_anomaly: true 
+enable_checkpointing: false +fast_dev_run: true +log_every_n_steps: 1 +max_epochs: 5 diff --git a/src/conf/trainer/fit.yaml b/src/conf/trainer/fit.yaml new file mode 100644 index 000000000..182552d0e --- /dev/null +++ b/src/conf/trainer/fit.yaml @@ -0,0 +1,7 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: 0 +limit_test_batches: 0 diff --git a/src/conf/trainer/fit2.yaml b/src/conf/trainer/fit2.yaml new file mode 100644 index 000000000..d58ac9e5c --- /dev/null +++ b/src/conf/trainer/fit2.yaml @@ -0,0 +1,11 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: 0 +limit_test_batches: 0 +limit_train_batches: 2 +limit_val_batches: 2 +log_every_n_steps: 1 +max_epochs: 3 diff --git a/src/conf/trainer/fit_test.yaml b/src/conf/trainer/fit_test.yaml new file mode 100644 index 000000000..13a58716c --- /dev/null +++ b/src/conf/trainer/fit_test.yaml @@ -0,0 +1,39 @@ +# @package trainer + +defaults: + - plugins: none + - profiler: none + - strategy: none + +_target_: "pytorch_lightning.trainer.Trainer" + +accelerator: "gpu" +accumulate_grad_batches: 1 +auto_scale_batch_size: null +benchmark: false +detect_anomaly: false +deterministic: false +devices: ${slurm.gpus} +enable_checkpointing: ${save} +enable_model_summary: false +fast_dev_run: false +gradient_clip_algorithm: "norm" +gradient_clip_val: 10 +limit_predict_batches: null +limit_test_batches: null +limit_train_batches: null +limit_val_batches: null +log_every_n_steps: 5 +max_epochs: 100 +max_steps: -1 +move_metrics_to_cpu: false +# multiple_trainloader_mode: "max_size_cycle", "min_size" +multiple_trainloader_mode: "max_size_cycle" +num_nodes: 1 +# precision: 32, 16 +precision: 32 +reload_dataloaders_every_n_epochs: 0 +resume_from_checkpoint: null +num_sanity_val_steps: 0 +track_grad_norm: -1 +val_check_interval: null diff --git a/src/conf/trainer/lim2.yaml b/src/conf/trainer/lim2.yaml new file mode 100644 index 000000000..745b3a4c6 --- /dev/null +++ b/src/conf/trainer/lim2.yaml @@ -0,0 +1,11 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: 2 +limit_test_batches: 2 +limit_train_batches: 2 +limit_val_batches: 2 +log_every_n_steps: 1 +max_epochs: 3 diff --git a/src/conf/trainer/plugins/lightning.yaml b/src/conf/trainer/plugins/lightning.yaml new file mode 100644 index 000000000..0666389c0 --- /dev/null +++ b/src/conf/trainer/plugins/lightning.yaml @@ -0,0 +1,3 @@ +# @package trainer.plugins + +_target_: "pytorch_lightning.plugins.environments.LightningEnvironment" diff --git a/src/conf/trainer/plugins/none.yaml b/src/conf/trainer/plugins/none.yaml new file mode 100644 index 000000000..fc0699815 --- /dev/null +++ b/src/conf/trainer/plugins/none.yaml @@ -0,0 +1,3 @@ +# @package trainer.plugins + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/trainer/plugins/slurm.yaml b/src/conf/trainer/plugins/slurm.yaml new file mode 100644 index 000000000..8940aa22d --- /dev/null +++ b/src/conf/trainer/plugins/slurm.yaml @@ -0,0 +1,5 @@ +# @package trainer.plugins + +_target_: "pytorch_lightning.plugins.environments.SLURMEnvironment" + +auto_requeue: false diff --git a/src/conf/trainer/predict2.yaml b/src/conf/trainer/predict2.yaml new file mode 100644 index 000000000..90f340709 --- /dev/null +++ b/src/conf/trainer/predict2.yaml @@ -0,0 +1,11 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: 2 +limit_test_batches: 0 +limit_train_batches: 0 +limit_val_batches: 0 +log_every_n_steps: 1 +max_epochs: 0 diff --git a/src/conf/trainer/profiler/none.yaml 
b/src/conf/trainer/profiler/none.yaml new file mode 100644 index 000000000..625cee92e --- /dev/null +++ b/src/conf/trainer/profiler/none.yaml @@ -0,0 +1,3 @@ +# @package trainer.profiler + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/trainer/profiler/pytorch.yaml b/src/conf/trainer/profiler/pytorch.yaml new file mode 100644 index 000000000..864eee64e --- /dev/null +++ b/src/conf/trainer/profiler/pytorch.yaml @@ -0,0 +1,11 @@ +# @package trainer.profiler + +_target_: "pytorch_lightning.profilers.PyTorchProfiler" + +dirpath: "${hydra.sweep.dir}/${hydra.sweep.subdir}" +filename: "pytorch_profiler" +record_shapes: True +profile_memory: True +with_stack: True +with_flops: True +with_modules: True diff --git a/src/conf/trainer/strategy/ddp.yaml b/src/conf/trainer/strategy/ddp.yaml new file mode 100644 index 000000000..729300883 --- /dev/null +++ b/src/conf/trainer/strategy/ddp.yaml @@ -0,0 +1,8 @@ +# @package trainer.strategy + +_target_: "pytorch_lightning.strategies.DDPStrategy" + +find_unused_parameters: false +static_graph: true +# null, "nccl", "mpi", "gloo" +# process_group_backend: null diff --git a/src/conf/trainer/strategy/ddp_spawn.yaml b/src/conf/trainer/strategy/ddp_spawn.yaml new file mode 100644 index 000000000..858b615be --- /dev/null +++ b/src/conf/trainer/strategy/ddp_spawn.yaml @@ -0,0 +1,3 @@ +# @package trainer.strategy + +_target_: "pytorch_lightning.strategies.DDPSpawnStrategy" diff --git a/src/conf/trainer/strategy/none.yaml b/src/conf/trainer/strategy/none.yaml new file mode 100644 index 000000000..27154e09a --- /dev/null +++ b/src/conf/trainer/strategy/none.yaml @@ -0,0 +1,3 @@ +# @package trainer.strategy + +_target_: "conette.utils.hydra.get_none" diff --git a/src/conf/trainer/test.yaml b/src/conf/trainer/test.yaml new file mode 100644 index 000000000..01acd68e9 --- /dev/null +++ b/src/conf/trainer/test.yaml @@ -0,0 +1,11 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: null +limit_test_batches: null +limit_train_batches: 0 +limit_val_batches: 0 +log_every_n_steps: 1 +max_epochs: 0 diff --git a/src/conf/trainer/test2.yaml b/src/conf/trainer/test2.yaml new file mode 100644 index 000000000..674b07f17 --- /dev/null +++ b/src/conf/trainer/test2.yaml @@ -0,0 +1,11 @@ +# @package trainer + +defaults: + - fit_test + +limit_predict_batches: 2 +limit_test_batches: 2 +limit_train_batches: 0 +limit_val_batches: 0 +log_every_n_steps: 1 +max_epochs: 0 diff --git a/tests/test_inference.py b/tests/test_inference.py index b1d3f4cb8..d858301e6 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -23,6 +23,7 @@ def test_example_1(self) -> None: outputs = self.model(path) candidate = outputs["cands"][0] + # expected: "rain is pouring down and people are talking in the background" self.assertIsInstance(candidate, str) def test_example_2(self) -> None: @@ -56,6 +57,14 @@ def test_forbid_rep_mode(self) -> None: self.assertIsInstance(cand, str) + def test_tags(self) -> None: + path = get_sample_path() + outputs = self.model(path, task="clotho", forbid_rep_mode="none", beam_size=1) + tags = outputs["tags"] + + assert tags is not None + self.assertIsInstance(tags, list) + if __name__ == "__main__": unittest.main()
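As background for the src/conf groups introduced above: each YAML file follows the standard Hydra pattern of a _target_ plus constructor arguments, and the conette-train entry point composes the groups listed in the defaults section of train.yaml. The sketch below is only an illustration and not part of the diff; it instantiates one of the added configs directly and assumes the conette package from this PR is installed so that the SpecAugmentRatio target can be imported.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load one of the config groups added in this diff. The "# @package audio_t"
# directive only matters during Hydra composition; OmegaConf.load reads it as
# a plain comment, so the file becomes a flat mapping.
cfg = OmegaConf.load("src/conf/audio_t/spec_aug_ratio_emb.yaml")

# instantiate() imports the class named by _target_
# (conette.transforms.audio.spec_aug.SpecAugmentRatio) and calls it with the
# remaining keys as keyword arguments.
spec_aug = instantiate(cfg)
print(type(spec_aug).__name__)  # expected: SpecAugmentRatio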