Skip to content

Commit

Permalink
Merge branch 'facebookresearch:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
0xlws authored Dec 17, 2023
2 parents f317a5a + 5c7ea98 commit e2d3f45
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Adding stereo models.

Fixed the commitment loss, which was until now only applied to the first RVQ layer.

Removed compression model state from the LM checkpoints, for consistency, it
should always be loaded from the original `compression_model_checkpoint`.


## [1.1.0] - 2023-11-06

Expand Down
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
pip install 'torch>=2.0'
python -m pip install 'torch==2.1.0'
# You might need the following before trying to install the packages
python -m pip install setuptools wheel
# Then proceed to one of the following
pip install -U audiocraft # stable release
pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
python -m pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
```

We also recommend having `ffmpeg` installed, either through your system or Anaconda:
Expand Down Expand Up @@ -72,11 +74,11 @@ Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and w

For the general framework of AudioCraft, please cite the following.
```
@article{copet2023simple,
@inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
journal={arXiv preprint arXiv:2306.05284},
}
```

Expand Down
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

__version__ = '1.2.0a1'
__version__ = '1.2.0a2'
1 change: 0 additions & 1 deletion audiocraft/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,6 @@ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforwar
# see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
# backward hook inside of FSDP...
layer._magma_checkpointed = True # type: ignore
assert layer.layer_drop == 0., "Need further checking" # type: ignore

def _apply_layer(self, layer, *args, **kwargs):
method = self.checkpointing
Expand Down
5 changes: 5 additions & 0 deletions audiocraft/quantization/core_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None):

for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
quantized = quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)

if self.training:
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
quantized_out = x + (quantized_out - x).detach()

out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses

Expand Down
20 changes: 18 additions & 2 deletions audiocraft/solvers/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash


class MusicGenSolver(base.StandardSolver):
Expand Down Expand Up @@ -143,7 +143,7 @@ def build_model(self) -> None:
# initialize optimization
self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
self.register_stateful('model', 'optimizer', 'lr_scheduler')
self.register_best_state('model')
self.autocast_dtype = {
'float16': torch.float16, 'bfloat16': torch.bfloat16
Expand Down Expand Up @@ -181,6 +181,22 @@ def load_state_dict(self, state: dict) -> None:
key = prefix + key
assert key not in model_state
model_state[key] = value
if 'compression_model' in state:
# We used to store the `compression_model` state in the checkpoint, however
# this is in general not needed, as the compression model should always be readable
# from the original `cfg.compression_model_checkpoint` location.
compression_model_state = state.pop('compression_model')
before_hash = model_hash(self.compression_model)
self.compression_model.load_state_dict(compression_model_state)
after_hash = model_hash(self.compression_model)
if before_hash != after_hash:
raise RuntimeError(
"The compression model state inside the checkpoint is different"
" from the one obtained from compression_model_checkpoint..."
"We do not support altering the compression model inside the LM "
"checkpoint as parts of the code, in particular for running eval post-training "
"will use the compression_model_checkpoint as the source of truth.")

super().load_state_dict(state)

def load_from_pretrained(self, name: str):
Expand Down
1 change: 1 addition & 0 deletions audiocraft/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def _load_one(self, index: int):
if isinstance(part[0], torch.Tensor):
out.append(torch.stack(part))
else:
assert isinstance(part, torch.Tensor)
out.append(part)
return out
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion demos/audiogen_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
" \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
" t = torch.arange(\n",
" int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
" wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
" wav = torch.cos(2 * math.pi * frequency * t)[None]\n",
" tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
" envelope = (tp >= 0.5).float()\n",
" return wav * envelope"
Expand Down
8 changes: 5 additions & 3 deletions demos/musicgen_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def _cleanup(self):
self.files.pop(0)
else:
break



file_cleaner = FileCleaner()


Expand All @@ -96,6 +95,9 @@ def load_model(version='facebook/musicgen-melody'):
global MODEL
print("Loading model", version)
if MODEL is None or MODEL.name != version:
# Clear PyTorch CUDA cache and delete model
del MODEL
torch.cuda.empty_cache()
MODEL = None # in case loading would crash
MODEL = MusicGen.get_pretrained(version)

Expand Down Expand Up @@ -256,7 +258,7 @@ def ui_full(launch_kwargs):
with gr.Column():
radio = gr.Radio(["file", "mic"], value="file",
label="Condition on a melody (optional) File or Mic")
melody = gr.Audio(source="upload", type="numpy", label="File",
melody = gr.Audio(sources=["upload"], type="numpy", label="File",
interactive=True, elem_id="melody-input")
with gr.Row():
submit = gr.Button("Submit")
Expand Down
8 changes: 4 additions & 4 deletions docs/MUSICGEN.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ Once you have launched some experiments, you can easily get access
to the Solver with the latest trained model using the following snippet.

```python
from audiocraft.solvers.musicgen import MusicGen
from audiocraft.solvers.musicgen import MusicGenSolver

solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
solver = MusicGenSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
solver.model
solver.dataloaders
```
Expand Down Expand Up @@ -401,11 +401,11 @@ activations by sharding the optimizer state.

## Citation
```
@article{copet2023simple,
@inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
journal={arXiv preprint arXiv:2306.05284},
}
```

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ julius
num2words
numpy
sentencepiece
spacy==3.5.2
torch>=2.0.0
spacy>=3.6.1
torch==2.1.0
torchaudio>=2.0.0
huggingface_hub
tqdm
Expand All @@ -20,4 +20,4 @@ librosa
gradio
torchmetrics
encodec
protobuf
protobuf
4 changes: 3 additions & 1 deletion tests/quantization/test_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
class TestResidualVectorQuantizer:

def test_rvq(self):
x = torch.randn(1, 16, 2048)
x = torch.randn(1, 16, 2048, requires_grad=True)
vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
res = vq(x, 1.)
assert res.x.shape == torch.Size([1, 16, 2048])
res.x.sum().backward()
assert torch.allclose(x.grad.data, torch.ones(1))

0 comments on commit e2d3f45

Please sign in to comment.