diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 58e6cc3..4b40757 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,27 +1,27 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - + name: build on: - push: + push: branches: - main paths-ignore: - 'README.md' - 'README_CN.md' - 'docs/**' - + pull_request: paths-ignore: - 'README.md' - 'README_CN.md' - 'docs/**' - + concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - + jobs: build_test: runs-on: ubuntu-18.04 @@ -39,14 +39,14 @@ jobs: # pip install torch numpy mmcv -i https://pypi.tuna.tsinghua.edu.cn/simple # pip install opencv-python>=3 yapf imageio scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple - # coverage run --source xrmogen/models -m pytest -s test/models + # coverage run --source xrmogen/models -m pytest -s test/models # coverage xml # coverage report -m - - name: Upload coverage to Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v2 with: files: ./coverage.xml flags: unittests env_vars: OS,PYTHON name: codecov-umbrella - fail_ci_if_error: false \ No newline at end of file + fail_ci_if_error: false diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f0a4623..4dc2868 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,11 +20,11 @@ jobs: sudo apt-add-repository ppa:brightbox/ruby-ng -y sudo apt-get update sudo apt-get install -y ruby2.7 - # pip install pre-commit - # pre-commit install + pip install pre-commit + pre-commit install - name: Linting - # run: pre-commit run --files xrmogen/* + run: pre-commit run --all-files - name: Check docstring coverage run: | pip install interrogate - # interrogate -vinmMI --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" -f 60 xrmogen/ + interrogate -vinmMI --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" -f 25 xrmogen/ diff --git a/.gitignore b/.gitignore index f98f487..d4b2f0f 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,4 @@ xrmogen.egg-info *.whl *ignore/ example/ -data/ \ No newline at end of file +data/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..1813618 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,44 @@ +repos: + - repo: https://gitlab.com/pycqa/flake8.git + rev: 3.8.3 + hooks: + - id: flake8 + - repo: https://github.com/asottile/seed-isort-config.git + rev: v2.2.0 + hooks: + - id: seed-isort-config + args: [--settings-path, ./] + - repo: https://github.com/PyCQA/isort.git + rev: 5.10.1 + hooks: + - id: isort + args: [--settings-file, ./setup.cfg] + - repo: https://github.com/pre-commit/mirrors-yapf.git + rev: v0.30.0 + hooks: + - id: yapf + - repo: https://github.com/pre-commit/pre-commit-hooks.git + rev: v3.1.0 + hooks: + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + exclude: .*/tests/data/ + - id: check-yaml + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: double-quote-string-fixer + - id: check-merge-conflict + - id: fix-encoding-pragma + args: ["--remove"] + - id: mixed-line-ending + args: ["--fix=lf"] + - repo: https://github.com/codespell-project/codespell + rev: v2.1.0 + hooks: + - id: codespell + args: ["--ignore-words-list", "ba"] + - repo: 
https://github.com/myint/docformatter.git + rev: v1.3.1 + hooks: + - id: docformatter + args: ["--in-place", "--wrap-descriptions", "79"] diff --git a/Dockerfile b/Dockerfile index ea3b926..08496d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -49,4 +49,4 @@ RUN . /root/miniconda3/etc/profile.d/conda.sh && \ pip install tqdm && \ pip install xrprimer && \ pip install -e . && \ - pip cache purge \ No newline at end of file + pip cache purge diff --git a/README.md b/README.md index 60afeb2..283c467 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ The model structure can be customized through config files. To implement a new m - `train_step()`: forward method of the training mode. - `val_step()`: forward method of the testing mode. -- regestered as a dance model +- registered as a dance model To be specific, if we want to implement a new model, there are several things to do. @@ -79,7 +79,7 @@ To be specific, if we want to implement a new model, there are several things to def __init__(self, model_config): super().__init__() - + def forward(self, ...): .... @@ -113,7 +113,7 @@ To be specific, if we want to implement a new model, there are several things to XRMoGen uses `mmcv.runner.EpochBasedRunner` to control training and test. -In the training mode, the `max_epochs` in config file decide how many epochs to train. +In the training mode, the `max_epochs` in config file decide how many epochs to train. In test mode, `max_epochs` is forced to change to 1, which represents only 1 epoch to test. Validation frequency is set as `workflow` of config file: @@ -125,7 +125,7 @@ Validation frequency is set as `workflow` of config file: For example, to train Bailando (Dance Revolution), ```shell -python main.py --config configs/dance_rev.py +python main.py --config configs/dance_rev.py ``` Arguments are: @@ -134,7 +134,7 @@ Arguments are: ### Test To test relevant model, add `--test_only` tag after the config path. -We provide some pretrained weights to test (see [pretrained_model_list.md](docs/en/pretrained_model_list.md). Download the pretrained weights under a folder `./example`, and run +We provide some pretrained weights to test (see [pretrained_model_list.md](docs/en/pretrained_model_list.md). Download the pretrained weights under a folder `./example`, and run ```shell python main.py --config configs/bailando_test.py --test_only @@ -201,4 +201,3 @@ We wish that the framework and benchmark could serve the growing research commun - [XRMoCap](https://github.com/openxrlab/xrmocap): OpenXRLab Multi-view Motion Capture Toolbox and Benchmark. - [XRMoGen](https://github.com/openxrlab/xrmogen): OpenXRLab Human Motion Generation Toolbox and Benchmark. - [XRNeRF](https://github.com/openxrlab/xrnerf): OpenXRLab Neural Radiance Field (NeRF) Toolbox and Benchmark. - diff --git a/README_CN.md b/README_CN.md index cee2d9f..778793d 100644 --- a/README_CN.md +++ b/README_CN.md @@ -78,7 +78,7 @@ xrmogen def __init__(self, model_config): super().__init__() - + def forward(self, ...): .... @@ -124,7 +124,7 @@ XRMoGen 使用 `mmcv.runner.EpochBasedRunner` (以epoch为单位)去训练 比如,为了训练DanceRevolution模型,运行以下命令 ```shell -python main.py --config configs/dance_rev.py +python main.py --config configs/dance_rev.py ``` 参数: @@ -189,5 +189,3 @@ XRMoGen 是一款由学校和公司共同贡献的开源项目。我们感谢所 - [XRMoCap](https://github.com/openxrlab/xrmocap): OpenXRLab Multi-view Motion Capture Toolbox and Benchmark. - [XRMoGen](https://github.com/openxrlab/xrmogen): OpenXRLab Human Motion Generation Toolbox and Benchmark. 
- [XRNeRF](https://github.com/openxrlab/xrnerf): OpenXRLab Neural Radiance Field (NeRF) Toolbox and Benchmark. - - diff --git a/configs/bailando_test.py b/configs/bailando_test.py index 2b32601..d7dfb41 100644 --- a/configs/bailando_test.py +++ b/configs/bailando_test.py @@ -1,10 +1,9 @@ -from email import policy import os from datetime import datetime num_gpus = 1 -## optimizer +# optimizer method = 'bailando' phase = 'gpt' @@ -18,19 +17,19 @@ lr_config = dict(policy='step', step=[250, 400], gamma=0.1) checkpoint_config = dict(interval=20, by_epoch=True) log_level = 'INFO' -log_config = dict(interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) +log_config = dict( + interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) workflow = [('train', 20), ('val', 1)] # hooks -# 'params' are numeric type value, 'variables' are variables in local environment +# 'params' are numeric type value, +# 'variables' are variables in local environment train_hooks = [ - dict(type='SaveDancePKLHook', - params=dict()), + dict(type='SaveDancePKLHook', params=dict()), ] test_hooks = [ - dict(type='SaveTestDancePKLHook', - params=dict(save_folder='test')), + dict(type='SaveTestDancePKLHook', params=dict(save_folder='test')), ] # runner @@ -41,37 +40,42 @@ num_gpus = 1 distributed = 0 # multi-gpu work_dir = './bailando_test/'.format(phase) # noqa -timestamp = datetime.now().strftime("%d-%b-%H-%M") - +timestamp = datetime.now().strftime('%d-%b-%H-%M') load_from = os.path.join('./example/bailando.pth') -## dataset +# dataset -traindata_cfg = dict( +traindata_cfg = dict( data_dir='data/aistpp_train_wav', rotmat=False, seq_len=240, mode='train', move=8, external_wav='data/aistpp_music_feat_7.5fps', - external_wav_rate=8 -) + external_wav_rate=8) -testdata_cfg = dict( +testdata_cfg = dict( data_dir='data/aistpp_test_full_wav', rotmat=False, mode='test', move=8, external_wav='data/aistpp_music_feat_7.5fps', - external_wav_rate=8 -) + external_wav_rate=8) train_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] test_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] data = dict( @@ -95,23 +99,28 @@ ), ) - -##### model +# model model = dict( type='Bailando', model_config=dict( bailando_phase='gpt', - vqvae=dict( + vqvae=dict( up_half=dict( levels=1, - downs_t=[3,], - strides_t =[2,], + downs_t=[ + 3, + ], + strides_t=[ + 2, + ], emb_width=512, l_bins=512, l_mu=0.99, commit=0.02, - hvqvae_multipliers=[1,], + hvqvae_multipliers=[ + 1, + ], width=512, depth=3, m_conv=1.0, @@ -119,30 +128,33 @@ sample_length=240, use_bottleneck=True, joint_channel=3, - vqvae_reverse_decoder_dilation=True - ), + vqvae_reverse_decoder_dilation=True), down_half=dict( levels=1, - downs_t=[3,], - strides_t =[2,], - emb_width =512, - l_bins =512, - l_mu =0.99, - commit =0.02, - hvqvae_multipliers =[1,], + downs_t=[ + 3, + ], + strides_t=[ + 2, + ], + emb_width=512, + l_bins=512, + l_mu=0.99, + commit=0.02, + hvqvae_multipliers=[ + 1, + ], width=512, depth=3, - m_conv =1.0, - dilation_growth_rate =3, + m_conv=1.0, + dilation_growth_rate=3, sample_length=240, use_bottleneck=True, joint_channel=3, - vqvae_reverse_decoder_dilation=True - ), + vqvae_reverse_decoder_dilation=True), use_bottleneck=True, joint_channel=3, ), - gpt=dict( block_size=29, base=dict( @@ -154,10 +166,9 @@ block_size=29, n_layer=6, n_head=12, - n_embd=768 , + 
n_embd=768, n_music=438, - n_music_emb=768 - ), + n_music_emb=768), head=dict( embd_pdrop=0.1, resid_pdrop=0.1, @@ -168,11 +179,6 @@ n_head=12, n_embd=768, vocab_size_up=512, - vocab_size_down=512 - ), + vocab_size_down=512), n_music=438, - n_music_emb=768 - ) - ) -) - \ No newline at end of file + n_music_emb=768))) diff --git a/configs/dance_rev.py b/configs/dance_rev.py index 19c3674..0255d85 100644 --- a/configs/dance_rev.py +++ b/configs/dance_rev.py @@ -1,12 +1,8 @@ -from email import policy import os from datetime import datetime - - num_gpus = 1 - method = 'dance revolution' phase = 'train' @@ -20,21 +16,20 @@ lr_config = dict(policy='step', step=[4, 6], gamma=0.1, by_epoch=True) checkpoint_config = dict(interval=1, by_epoch=True) log_level = 'INFO' -log_config = dict(interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) +log_config = dict( + interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) workflow = [('train', 1), ('val', 1)] # workflow = [('val', 1)] # hooks -# 'params' are numeric type value, 'variables' are variables in local environment +# 'params' are numeric type value, +# 'variables' are variables in local environment train_hooks = [ - dict(type='PassEpochNumberToModelHook', - params=dict()), - dict(type='SaveDancePKLHook', - params=dict()), + dict(type='PassEpochNumberToModelHook', params=dict()), + dict(type='SaveDancePKLHook', params=dict()), ] test_hooks = [ - dict(type='SaveTestDancePKLHook', - params=dict(save_folder='test')), + dict(type='SaveTestDancePKLHook', params=dict(save_folder='test')), ] # runner @@ -45,31 +40,33 @@ num_gpus = 1 distributed = 0 # multi-gpu work_dir = './dance_rev/'.format(phase) # noqa -timestamp = datetime.now().strftime("%d-%b-%H-%M") +timestamp = datetime.now().strftime('%d-%b-%H-%M') +# dataset -## dataset - -traindata_cfg = dict( +traindata_cfg = dict( data_dir='data/aistpp_train_wav', rotmat=False, seq_len=240, mode='train', - move=1 -) + move=1) -testdata_cfg = dict( - data_dir='data/aistpp_test_full_wav', - rotmat=False, - mode='test', - move=1 -) +testdata_cfg = dict( + data_dir='data/aistpp_test_full_wav', rotmat=False, mode='test', move=1) train_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] test_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] data = dict( @@ -94,11 +91,11 @@ ) load_from = os.path.join(work_dir, 'epoch_15.pth') -##### model +# model model = dict( type='DanceRevolution', model_config=dict( - #ChoreoGrapher Configs + # ChoreoGrapher Configs max_seq_len=4500, d_frame_vec=438, frame_emb_size=200, @@ -115,7 +112,4 @@ sliding_windown_size=100, lambda_v=0.01, cuda=True, - rotmat=False - ) -) - + rotmat=False)) diff --git a/configs/dance_rev_test.py b/configs/dance_rev_test.py index 1d0f0ed..a445248 100644 --- a/configs/dance_rev_test.py +++ b/configs/dance_rev_test.py @@ -1,12 +1,8 @@ -from email import policy import os from datetime import datetime - - num_gpus = 1 - method = 'dance revolution' phase = 'test' @@ -20,21 +16,20 @@ lr_config = dict(policy='step', step=[4, 6], gamma=0.1, by_epoch=True) checkpoint_config = dict(interval=1, by_epoch=True) log_level = 'INFO' -log_config = dict(interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) +log_config = dict( + interval=10, by_epoch=False, hooks=[dict(type='TextLoggerHook')]) workflow = [('train', 1), ('val', 1)] # 
workflow = [('val', 1)] # hooks -# 'params' are numeric type value, 'variables' are variables in local environment +# 'params' are numeric type value, +# 'variables' are variables in local environment train_hooks = [ - dict(type='PassEpochNumberToModelHook', - params=dict()), - dict(type='SaveDancePKLHook', - params=dict()), + dict(type='PassEpochNumberToModelHook', params=dict()), + dict(type='SaveDancePKLHook', params=dict()), ] test_hooks = [ - dict(type='SaveTestDancePKLHook', - params=dict(save_folder='test')), + dict(type='SaveTestDancePKLHook', params=dict(save_folder='test')), ] # runner @@ -45,31 +40,33 @@ num_gpus = 1 distributed = 0 # multi-gpu work_dir = './dance_rev_test/'.format(phase) # noqa -timestamp = datetime.now().strftime("%d-%b-%H-%M") +timestamp = datetime.now().strftime('%d-%b-%H-%M') +# dataset -## dataset - -traindata_cfg = dict( +traindata_cfg = dict( data_dir='data/aistpp_train_wav', rotmat=False, seq_len=240, mode='train', - move=1 -) + move=1) -testdata_cfg = dict( - data_dir='data/aistpp_test_full_wav', - rotmat=False, - mode='test', - move=1 -) +testdata_cfg = dict( + data_dir='data/aistpp_test_full_wav', rotmat=False, mode='test', move=1) train_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] test_pipeline = [ - dict(type='ToTensor', enable=True, keys=['music', 'dance'],), + dict( + type='ToTensor', + enable=True, + keys=['music', 'dance'], + ), ] data = dict( @@ -94,11 +91,11 @@ ) load_from = os.path.join('./example/dance_revolution.pth') -##### model +# model model = dict( type='DanceRevolution', model_config=dict( - #ChoreoGrapher Configs + # ChoreoGrapher Configs max_seq_len=4500, d_frame_vec=438, frame_emb_size=200, @@ -115,7 +112,4 @@ sliding_windown_size=100, lambda_v=0.01, cuda=True, - rotmat=False - ) -) - + rotmat=False)) diff --git a/docs/en/apis.md b/docs/en/apis.md index 6d12e61..6b77b1a 100755 --- a/docs/en/apis.md +++ b/docs/en/apis.md @@ -18,4 +18,4 @@ Purpose: Generate a model for test actions based on running parameters ## parse_args Input: args, run arguments -Purpose: Convert running parameters to mmcv.Config \ No newline at end of file +Purpose: Convert running parameters to mmcv.Config diff --git a/docs/en/changelog.md b/docs/en/changelog.md index 4e768b5..792d600 100755 --- a/docs/en/changelog.md +++ b/docs/en/changelog.md @@ -1 +1 @@ -# \ No newline at end of file +# diff --git a/docs/en/dataset_preparation.md b/docs/en/dataset_preparation.md index 29bb31d..73ab7fc 100755 --- a/docs/en/dataset_preparation.md +++ b/docs/en/dataset_preparation.md @@ -33,5 +33,3 @@ xrmogen │ ├── musics ├── ... ``` - - diff --git a/docs/en/faq.md b/docs/en/faq.md index a3f574b..a6738c3 100755 --- a/docs/en/faq.md +++ b/docs/en/faq.md @@ -1,4 +1,3 @@ # FAQ ## Outline - diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 346c86b..888dc60 100755 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -58,7 +58,7 @@ To implement a new method, your model need to contain following functions/medhot - `train_step()`: forward method of the training mode. - `val_step()`: forward method of the testing mode. -- regestered as a dance model +- registered as a dance model To be specific, if we want to implement a new model, there are several things to do. @@ -74,7 +74,7 @@ To be specific, if we want to implement a new model, there are several things to def __init__(self, model_config): super().__init__() - + def forward(self, ...): .... 
@@ -108,7 +108,7 @@ To be specific, if we want to implement a new model, there are several things to XRMoGen uses `mmcv.runner.EpochBasedRunner` to control training and test. -In the training mode, the `max_epochs` in config file decide how many epochs to train. +In the training mode, the `max_epochs` in config file decide how many epochs to train. In test mode, `max_epochs` is forced to change to 1, which represents only 1 epoch to test. Validation frequency is set as `workflow` of config file: @@ -120,7 +120,7 @@ Validation frequency is set as `workflow` of config file: For example, to train Bailando (Motion VQVAE phase), ```shell -python main.py --config configs/config/bailando_motion_vqvae.py +python main.py --config configs/config/bailando_motion_vqvae.py ``` Arguments are: diff --git a/docs/en/installation.md b/docs/en/installation.md index 2aca1e6..14f9bcd 100755 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -11,14 +11,14 @@ We provide some tips for XRMoGen installation in this file. - [Prepare environment](#prepare-environment) - [a. Create a conda virtual environment and activate it.](#a-create-a-conda-virtual-environment-and-activate-it) - [b. Install PyTorch and torchvision](#b-install-pytorch-and-torchvision) - + - [c. Install MMHuman3D](#c-install-mmhuman3d) - [d. Install Other Needed Python Packages](#d-install-other-needed-python-packages) - [Another option: Docker Image](#another-option-docker-image) - - + @@ -108,7 +108,7 @@ where [DOCKER_ID] is the docker id that can be obtained by ``` docker ps -a ``` - + - diff --git a/docs/en/pretrained_model_list.md b/docs/en/pretrained_model_list.md index 5e1bc73..2fab4f3 100644 --- a/docs/en/pretrained_model_list.md +++ b/docs/en/pretrained_model_list.md @@ -1,8 +1,7 @@ -# Pretrained model weights +# Pretrained model weights | model | weight link | -| :----------------: | :------------------------: | +| :----------------: | :------------------------: | | Bailando | [bailando.pth](https://openxrlab-share.oss-cn-hongkong.aliyuncs.com/xrmogen/weights/bailando.pth) | | DanceRevolution | [dance_revolution.pth](https://openxrlab-share.oss-cn-hongkong.aliyuncs.com/xrmogen/weights/dance_revolution.pth) | - diff --git a/docs/en/tutorials/config.md b/docs/en/tutorials/config.md index 639b4f3..910c0b6 100755 --- a/docs/en/tutorials/config.md +++ b/docs/en/tutorials/config.md @@ -1,7 +1,7 @@ # Tutorial 1: How to write a config file -In XRMoGen, configuration (config) files are implemented in python. +In XRMoGen, configuration (config) files are implemented in python. A config file contains the configuration required for all experiments, including training and testing pipelines, model, dataset, and other hyperparameters. All configuration files provided by XRMoGen are under the `$PROJECT/configs` folder. @@ -48,7 +48,7 @@ Let's take the training config of the Bailando model as an example: ``` Under the mmcv framework, IOs of training and test, like transmitting information to the model outside the standard dataloader, or storing network results, need to be implemented through hooks. 
- Required hooks are decalared in config + Required hooks are declared in config ```python train_hooks = [ @@ -89,7 +89,7 @@ Let's take the training config of the Bailando model as an example: type='Bailando', model_config=dict( bailando_phase='motion vqvae', - vqvae=dict( + vqvae=dict( up_half=dict( levels=1, downs_t=[3,], @@ -155,7 +155,7 @@ Let's take the training config of the Bailando model as an example: n_head=12, n_embd=768, vocab_size_up=512, - vocab_size_down=512 + vocab_size_down=512 ), n_music=438, n_music_emb=768 @@ -164,10 +164,10 @@ Let's take the training config of the Bailando model as an example: ) ``` -* Data: +* Data: The data part defines the data set type, data processing flow, batch size and other information. - ```python - traindata_cfg = dict( + ```python + traindata_cfg = dict( data_dir='/mnt/lustre/syli/dance/Bailando/data/aistpp_train_wav', rotmat=False, seq_len=240, @@ -177,7 +177,7 @@ Let's take the training config of the Bailando model as an example: external_wav_rate=8 ) - testdata_cfg = dict( + testdata_cfg = dict( data_dir='/mnt/lustre/syli/dance/Bailando/data/aistpp_test_full_wav', rotmat=False, mode='test', diff --git a/docs/en/tutorials/model.md b/docs/en/tutorials/model.md index 719d86b..aef1c8d 100755 --- a/docs/en/tutorials/model.md +++ b/docs/en/tutorials/model.md @@ -44,4 +44,4 @@ Output of validation is a dictionary, where `output_pose` is generated dance wit } ``` -Output pose will be stored in `.pkl` format after validation. \ No newline at end of file +Output pose will be stored in `.pkl` format after validation. diff --git a/docs/zh_cn/changelog.md b/docs/zh_cn/changelog.md index 4e768b5..792d600 100755 --- a/docs/zh_cn/changelog.md +++ b/docs/zh_cn/changelog.md @@ -1 +1 @@ -# \ No newline at end of file +# diff --git a/docs/zh_cn/dataset_preparation.md b/docs/zh_cn/dataset_preparation.md index c2e55aa..76ae4f4 100755 --- a/docs/zh_cn/dataset_preparation.md +++ b/docs/zh_cn/dataset_preparation.md @@ -33,5 +33,3 @@ xrmogen │ ├── musics ├── ... ``` - - diff --git a/docs/zh_cn/faq.md b/docs/zh_cn/faq.md index 2081d86..a6738c3 100755 --- a/docs/zh_cn/faq.md +++ b/docs/zh_cn/faq.md @@ -1,5 +1,3 @@ # FAQ ## Outline - - diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index 1a28dcc..c965a56 100755 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -75,7 +75,7 @@ xrmogen def __init__(self, model_config): super().__init__() - + def forward(self, ...): .... 
@@ -121,7 +121,7 @@ XRMoGen 使用 `mmcv.runner.EpochBasedRunner` (以epoch为单位)去训练 比如,为了训练Bailando模型 (Motion VQVAE phase),运行以下命令 ```shell -python main.py --config configs/config/bailando_motion_vqvae.py +python main.py --config configs/config/bailando_motion_vqvae.py ``` 参数: diff --git a/docs/zh_cn/installation.md b/docs/zh_cn/installation.md index 943bf72..58a2e2a 100755 --- a/docs/zh_cn/installation.md +++ b/docs/zh_cn/installation.md @@ -16,7 +16,7 @@ - - + @@ -107,7 +107,7 @@ docker cp ProjectPath/xrmogen [DOCKER_ID]:/workspace ``` docker ps -a ``` - + - diff --git a/docs/zh_cn/pretrained_model_list.md b/docs/zh_cn/pretrained_model_list.md index 24be105..bad2224 100644 --- a/docs/zh_cn/pretrained_model_list.md +++ b/docs/zh_cn/pretrained_model_list.md @@ -1,8 +1,7 @@ # 预训练模型 | 模型 | 预训练权重链接 | -| :----------------: | :------------------------: | +| :----------------: | :------------------------: | | Bailando | [bailando.pth](https://openxrlab-share.oss-cn-hongkong.aliyuncs.com/xrmogen/weights/bailando.pth) | | DanceRevolution | [dance_revolution.pth](https://openxrlab-share.oss-cn-hongkong.aliyuncs.com/xrmogen/weights/dance_revolution.pth) | - diff --git a/docs/zh_cn/tutorials/config.md b/docs/zh_cn/tutorials/config.md index 53e8c6d..757e62e 100755 --- a/docs/zh_cn/tutorials/config.md +++ b/docs/zh_cn/tutorials/config.md @@ -85,7 +85,7 @@ XRMoGen 提供的所有配置文件都放置在 `$PROJECT/configs` 文件夹下 type='Bailando', model_config=dict( bailando_phase='motion vqvae', - vqvae=dict( + vqvae=dict( up_half=dict( levels=1, downs_t=[3,], @@ -151,7 +151,7 @@ XRMoGen 提供的所有配置文件都放置在 `$PROJECT/configs` 文件夹下 n_head=12, n_embd=768, vocab_size_up=512, - vocab_size_down=512 + vocab_size_down=512 ), n_music=438, n_music_emb=768 @@ -162,8 +162,8 @@ XRMoGen 提供的所有配置文件都放置在 `$PROJECT/configs` 文件夹下 * 数据 数据部分的配置信息,定义了数据集类型,数据的处理流程,batchsize等等信息。 - ```python - traindata_cfg = dict( + ```python + traindata_cfg = dict( data_dir='/mnt/lustre/syli/dance/Bailando/data/aistpp_train_wav', rotmat=False, seq_len=240, @@ -173,7 +173,7 @@ XRMoGen 提供的所有配置文件都放置在 `$PROJECT/configs` 文件夹下 external_wav_rate=8 ) - testdata_cfg = dict( + testdata_cfg = dict( data_dir='/mnt/lustre/syli/dance/Bailando/data/aistpp_test_full_wav', rotmat=False, mode='test', diff --git a/docs/zh_cn/tutorials/model.md b/docs/zh_cn/tutorials/model.md index 719d86b..aef1c8d 100755 --- a/docs/zh_cn/tutorials/model.md +++ b/docs/zh_cn/tutorials/model.md @@ -44,4 +44,4 @@ Output of validation is a dictionary, where `output_pose` is generated dance wit } ``` -Output pose will be stored in `.pkl` format after validation. \ No newline at end of file +Output pose will be stored in `.pkl` format after validation. 
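The README and get_started/tutorial hunks above describe the contract a new dance model must satisfy: an `__init__(self, model_config)`, a `forward()`, a `train_step()` for the training mode, a `val_step()` for the testing mode, and registration as a dance model whose validation output dictionary carries `output_pose` (later dumped to `.pkl`). A minimal sketch of that contract is given below for orientation only; the class name, the `nn.Module` base class, and the exact method signatures beyond what the docs show are assumptions, and the real registry decorator is not visible in this diff.

```python
import torch.nn as nn


class NewDanceModel(nn.Module):
    """Hypothetical dance model following the pattern described in the docs."""

    def __init__(self, model_config):
        super().__init__()
        # hyperparameters are taken from the `model_config` dict in the config file
        self.model_config = model_config

    def forward(self, music, dance=None):
        # produce (or score) a dance sequence conditioned on music features;
        # body elided, this is only a sketch of the interface
        ...

    def train_step(self, data, optimizer=None):
        # forward method of the training mode: returns losses to optimize;
        # 'music' and 'dance' are the keys used by the ToTensor pipeline above
        loss = self.forward(data['music'], data['dance'])
        return {'loss': loss}

    def val_step(self, data, optimizer=None):
        # forward method of the testing mode: returns the generated dance,
        # which SaveDancePKLHook-style hooks store as .pkl after validation
        output_pose = self.forward(data['music'])
        return {'output_pose': output_pose}
```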
diff --git a/main.py b/main.py index bc89fa4..ede33f6 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from xrmogen.core.apis import * +from xrmogen.core.apis import parse_args, run_mogen if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 6da8782..a974afd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ -imageio==2.15.0 -mmcv==1.6.1 -numpy -opencv_python -Pillow -scipy -xrprimer +imageio==2.15.0 +mmcv==1.6.1 +numpy +opencv_python +Pillow +scipy tqdm - +xrprimer diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8ede64a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,20 @@ +[bdist_wheel] +universal=1 + +[aliases] +test=pytest + +[yapf] +based_on_style = pep8 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 79 +multi_line_output = 5 +include_trailing_comma = true +known_standard_library = pkg_resources,setuptools +known_first_party = xrmogen +known_third_party =mmcv,numpy,scipy,torch,tqdm,utils,xrprimer +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY diff --git a/srun_test.sh b/srun_test.sh index a1e919f..1f5e573 100644 --- a/srun_test.sh +++ b/srun_test.sh @@ -15,4 +15,4 @@ SCRIPT1="main.py" PYTHON_SCRIPT1="$PYTHON $SCRIPT1 --config $CONFIG --test_only" echo "$SRUN $PYTHON_SCRIPT1" -$SRUN $PYTHON_SCRIPT1 \ No newline at end of file +$SRUN $PYTHON_SCRIPT1 diff --git a/srun_train.sh b/srun_train.sh index d4751a1..2d462f4 100644 --- a/srun_train.sh +++ b/srun_train.sh @@ -15,4 +15,4 @@ SCRIPT1="main.py" PYTHON_SCRIPT1="$PYTHON $SCRIPT1 --config $CONFIG " echo "$SRUN $PYTHON_SCRIPT1" -$SRUN $PYTHON_SCRIPT1 \ No newline at end of file +$SRUN $PYTHON_SCRIPT1 diff --git a/tools/eval_quantitative_scores.py b/tools/eval_quantitative_scores.py index 408c2b2..6858c1d 100644 --- a/tools/eval_quantitative_scores.py +++ b/tools/eval_quantitative_scores.py @@ -1,61 +1,79 @@ -import numpy as np -import pickle -from utils.dance_features.kinetic import extract_kinetic_features -from utils.dance_features.manual import extract_manual_features -from scipy import linalg import argparse import mmcv -from scipy.ndimage import gaussian_filter as G +import numpy as np +import os +from scipy import linalg +from scipy.ndimage import gaussian_filter as G from scipy.signal import argrelextrema +from utils.dance_features.kinetic import extract_kinetic_features +from utils.dance_features.manual import extract_manual_features + -import os def normalize(feat, feat2): mean = feat.mean(axis=0) std = feat.std(axis=0) - + return (feat - mean) / (std + 1e-10), (feat2 - mean) / (std + 1e-10) -def calc_motion_quality(predicted_pkl_root, gt_pkl_root): +def calc_motion_quality(predicted_pkl_root, gt_pkl_root): pred_features_k = [] pred_features_m = [] gt_freatures_k = [] gt_freatures_m = [] - pred_features_k = [mmcv.load(os.path.join(predicted_pkl_root, 'kinetic_features', pkl)) for pkl in os.listdir(os.path.join(predicted_pkl_root, 'kinetic_features')) if pkl.endswith('.pkl')] - pred_features_m = [mmcv.load(os.path.join(predicted_pkl_root, 'manual_features_new', pkl)) for pkl in os.listdir(os.path.join(predicted_pkl_root, 'manual_features_new')) if pkl.endswith('.pkl')] - - gt_freatures_k = [np.load(os.path.join(gt_pkl_root, 'kinetic_features', pkl)) for pkl in os.listdir(os.path.join(gt_pkl_root, 'kinetic_features')) ] - gt_freatures_m = [np.load(os.path.join(gt_pkl_root, 'manual_features_new', pkl)) for pkl in os.listdir(os.path.join(gt_pkl_root, 'manual_features_new')) 
] - - - pred_features_k = np.stack(pred_features_k) # Nx72 p40 - pred_features_m = np.stack(pred_features_m) # Nx32 - gt_freatures_k = np.stack(gt_freatures_k) # N' x 72 N' >> N - gt_freatures_m = np.stack(gt_freatures_m) # + pred_features_k = [ + mmcv.load(os.path.join(predicted_pkl_root, 'kinetic_features', pkl)) + for pkl in os.listdir( + os.path.join(predicted_pkl_root, 'kinetic_features')) + if pkl.endswith('.pkl') + ] + pred_features_m = [ + mmcv.load( + os.path.join(predicted_pkl_root, 'manual_features_new', pkl)) + for pkl in os.listdir( + os.path.join(predicted_pkl_root, 'manual_features_new')) + if pkl.endswith('.pkl') + ] + + gt_freatures_k = [ + np.load(os.path.join(gt_pkl_root, 'kinetic_features', pkl)) + for pkl in os.listdir(os.path.join(gt_pkl_root, 'kinetic_features')) + ] + gt_freatures_m = [ + np.load(os.path.join(gt_pkl_root, 'manual_features_new', pkl)) + for pkl in os.listdir( + os.path.join(gt_pkl_root, 'manual_features_new')) + ] + pred_features_k = np.stack(pred_features_k) # Nx72 p40 + pred_features_m = np.stack(pred_features_m) # Nx32 + gt_freatures_k = np.stack(gt_freatures_k) # N' x 72 N' >> N + gt_freatures_m = np.stack(gt_freatures_m) # - gt_freatures_k, pred_features_k = normalize(gt_freatures_k, pred_features_k) - gt_freatures_m, pred_features_m = normalize(gt_freatures_m, pred_features_m) - + gt_freatures_k, pred_features_k = normalize(gt_freatures_k, + pred_features_k) + gt_freatures_m, pred_features_m = normalize(gt_freatures_m, + pred_features_m) fid_k = calc_fid(pred_features_k, gt_freatures_k) fid_m = calc_fid(pred_features_m, gt_freatures_m) - div_k = calculate_avg_distance(pred_features_k) div_m = calculate_avg_distance(pred_features_m) - - metrics = {'FIDk': fid_k.real, 'FIDg': fid_m.real, 'DIVk': div_k, 'DIVg' : div_m} + metrics = { + 'FIDk': fid_k.real, + 'FIDg': fid_m.real, + 'DIVk': div_k, + 'DIVg': div_m + } return metrics def calc_fid(kps_gen, kps_gt): - """ - compute FID between features of generated dance and GT - """ + """compute FID between features of generated dance and GT.""" mu_gen = np.mean(kps_gen, axis=0) sigma_gen = np.cov(kps_gen, rowvar=False) @@ -63,7 +81,7 @@ def calc_fid(kps_gen, kps_gt): mu_gt = np.mean(kps_gt, axis=0) sigma_gt = np.cov(kps_gt, rowvar=False) - mu1,mu2,sigma1,sigma2 = mu_gen, mu_gt, sigma_gen, sigma_gt + mu1, mu2, sigma1, sigma2 = mu_gen, mu_gt, sigma_gen, sigma_gt diff = mu1 - mu2 eps = 1e-5 @@ -79,21 +97,22 @@ def calc_fid(kps_gen, kps_gt): # Numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): - m = np.max(np.abs(covmean.imag)) + # m = np.max(np.abs(covmean.imag)) # raise ValueError('Imaginary component {}'.format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) - return (diff.dot(diff) + np.trace(sigma1) - + np.trace(sigma2) - 2 * tr_covmean) + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - + 2 * tr_covmean) def calc_diversity(feats): feat_array = np.array(feats) n, c = feat_array.shape diff = np.array([feat_array] * n) - feat_array.reshape(n, 1, c) - return np.sqrt(np.sum(diff**2, axis=2)).sum() / n / (n-1) + return np.sqrt(np.sum(diff**2, axis=2)).sum() / n / (n - 1) + def calculate_avg_distance(feature_list, mean=None, std=None): feature_list = np.stack(feature_list) @@ -108,6 +127,7 @@ def calculate_avg_distance(feature_list, mean=None, std=None): dist /= (n * n - n) / 2 return dist + def calc_and_save_feats(root, start=0, end=1200): """ compute and save motion features @@ -123,22 
+143,33 @@ def calc_and_save_feats(root, start=0, end=1200): # print(root) for pkl in os.listdir(root): print(pkl) - if (os.path.exists(os.path.join(root, 'kinetic_features', pkl)) and os.path.exists(os.path.join(root, 'manual_features_new', pkl))) or os.path.isdir(os.path.join(root, pkl)): + if (os.path.exists(os.path.join(root, 'kinetic_features', pkl)) and + os.path.exists(os.path.join(root, 'manual_features_new', + pkl))) or os.path.isdir( + os.path.join(root, pkl)): continue - joint3d = mmcv.load(os.path.join(root, pkl)).reshape(-1, 72)[start:end,:] + joint3d = mmcv.load(os.path.join(root, pkl)).reshape(-1, + 72)[start:end, :] roott = joint3d[:1, :3] # the root Tx72 (Tx(24x3)) - joint3d = joint3d - np.tile(roott, (1, 24)) # Calculate relative offset with respect to root + joint3d = joint3d - np.tile( + roott, (1, 24)) # Calculate relative offset with respect to root - mmcv.dump(extract_kinetic_features(joint3d.reshape(-1, 24, 3)), os.path.join(root, 'kinetic_features', pkl)) - mmcv.dump(extract_manual_features(joint3d.reshape(-1, 24, 3)), os.path.join(root, 'manual_features_new', pkl)) + mmcv.dump( + extract_kinetic_features(joint3d.reshape(-1, 24, 3)), + os.path.join(root, 'kinetic_features', pkl)) + mmcv.dump( + extract_manual_features(joint3d.reshape(-1, 24, 3)), + os.path.join(root, 'manual_features_new', pkl)) def get_music_beat(music_feature_root, key, length=None): """ - Fetch music beats from preprocessed music features, represented as bool (True=beats) + Fetch music beats from preprocessed music features, + represented as bool (True=beats) Args: - music_feature_root: the root folder of preprocessed music features + music_feature_root: the root folder of + preprocessed music features key: dance name length: restriction on sample length """ @@ -151,13 +182,14 @@ def get_music_beat(music_feature_root, key, length=None): beats = beats.astype(bool) beat_axis = np.arange(len(beats)) beat_axis = beat_axis[beats] - + return beat_axis def calc_dance_beat(keypoints): keypoints = np.array(keypoints).reshape(-1, 24, 3) - kinetic_vel = np.mean(np.sqrt(np.sum((keypoints[1:] - keypoints[:-1]) ** 2, axis=2)), axis=1) + kinetic_vel = np.mean( + np.sqrt(np.sum((keypoints[1:] - keypoints[:-1])**2, axis=2)), axis=1) kinetic_vel = G(kinetic_vel, 5) motion_beats = argrelextrema(kinetic_vel, np.less) return motion_beats, len(kinetic_vel) @@ -166,7 +198,7 @@ def calc_dance_beat(keypoints): def beat_align_score(music_beats, motion_beats): ba = 0 for bb in music_beats: - ba += np.exp(-np.min((motion_beats[0] - bb)**2) / 2 / 9) + ba += np.exp(-np.min((motion_beats[0] - bb)**2) / 2 / 9) return (ba / len(music_beats)) @@ -178,13 +210,13 @@ def calc_beat_align_score(pkl_root, music_feature_root): continue joint3d = mmcv.load(os.path.join(pkl_root, pkl)) - dance_beats, length = calc_dance_beat(joint3d) - music_beats = get_music_beat(music_feature_root, pkl.split('.')[0] + '.json', length) + dance_beats, length = calc_dance_beat(joint3d) + music_beats = get_music_beat(music_feature_root, + pkl.split('.')[0] + '.json', length) ba_scores.append(beat_align_score(music_beats, dance_beats)) - - return np.mean(ba_scores) + return np.mean(ba_scores) if __name__ == '__main__': @@ -204,8 +236,11 @@ def calc_beat_align_score(pkl_root, music_feature_root): # music-beat align score print('Calculating Music-dance beat alignment score') - metrics.update(dict(BeatAlignScore=calc_beat_align_score(args.pkl_root, args.music_feature_root))) - + metrics.update( + dict( + BeatAlignScore=calc_beat_align_score(args.pkl_root, + 
args.music_feature_root))) + print('Quantitative scores:', metrics) print(metrics) mmcv.dump(metrics, args.pkl_root + '_scores.json') diff --git a/tools/utils/dance_features/kinetic.py b/tools/utils/dance_features/kinetic.py index 7fdabc8..a6346c6 100644 --- a/tools/utils/dance_features/kinetic.py +++ b/tools/utils/dance_features/kinetic.py @@ -5,55 +5,72 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Modified by Ruilong Li -# Redistribution and use in source and binary forms, with or without modification, +# Redistribution and use in source and binary forms, +# with or without modification, # are permitted provided that the following conditions are met: -# * Redistributions of source code must retain the above copyright notice, this +# * Redistributions of source code must retain the above +# copyright notice, this # list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation +# * Redistributions in binary form must reproduce the +# above copyright notice, +# this list of conditions and the following disclaimer +# in the documentation # and/or other materials provided with the distribution. -# * Neither the name Facebook nor the names of its contributors may be used to -# endorse or promote products derived from this software without specific +# * Neither the name Facebook nor the names of +# its contributors may be used to +# endorse or promote products derived from +# this software without specific # prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT +# HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, +# BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np + from . 
import utils as feat_utils def extract_kinetic_features(positions): - assert len(positions.shape) == 3 # (seq_len, n_joints, 3) + assert len(positions.shape) == 3 # (seq_len, n_joints, 3) features = KineticFeatures(positions) kinetic_feature_vector = [] for i in range(positions.shape[1]): - feature_vector = np.hstack( - [ - features.average_kinetic_energy_horizontal(i), - features.average_kinetic_energy_vertical(i), - features.average_energy_expenditure(i), - ] - ) + feature_vector = np.hstack([ + features.average_kinetic_energy_horizontal(i), + features.average_kinetic_energy_vertical(i), + features.average_energy_expenditure(i), + ]) kinetic_feature_vector.extend(feature_vector) kinetic_feature_vector = np.array(kinetic_feature_vector, dtype=np.float32) return kinetic_feature_vector class KineticFeatures: - def __init__( - self, positions, frame_time=1./60, up_vec="y", sliding_window=2 - ): + + def __init__(self, + positions, + frame_time=1. / 60, + up_vec='y', + sliding_window=2): self.positions = positions self.frame_time = frame_time self.up_vec = up_vec @@ -63,12 +80,10 @@ def average_kinetic_energy(self, joint): average_kinetic_energy = 0 for i in range(1, len(self.positions)): average_velocity = feat_utils.calc_average_velocity( - self.positions, i, joint, self.sliding_window, self.frame_time - ) - average_kinetic_energy += average_velocity ** 2 + self.positions, i, joint, self.sliding_window, self.frame_time) + average_kinetic_energy += average_velocity**2 average_kinetic_energy = average_kinetic_energy / ( - len(self.positions) - 1.0 - ) + len(self.positions) - 1.0) return average_kinetic_energy def average_kinetic_energy_horizontal(self, joint): @@ -82,7 +97,7 @@ def average_kinetic_energy_horizontal(self, joint): self.frame_time, self.up_vec, ) - val += average_velocity ** 2 + val += average_velocity**2 val = val / (len(self.positions) - 1.0) return val @@ -97,7 +112,7 @@ def average_kinetic_energy_vertical(self, joint): self.frame_time, self.up_vec, ) - val += average_velocity ** 2 + val += average_velocity**2 val = val / (len(self.positions) - 1.0) return val @@ -105,7 +120,6 @@ def average_energy_expenditure(self, joint): val = 0.0 for i in range(1, len(self.positions)): val += feat_utils.calc_average_acceleration( - self.positions, i, joint, self.sliding_window, self.frame_time - ) + self.positions, i, joint, self.sliding_window, self.frame_time) val = val / (len(self.positions) - 1.0) return val diff --git a/tools/utils/dance_features/manual.py b/tools/utils/dance_features/manual.py index 2e4edb8..a09eda7 100644 --- a/tools/utils/dance_features/manual.py +++ b/tools/utils/dance_features/manual.py @@ -5,142 +5,144 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Modified by Ruilong Li -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: +# Redistribution and use in source and binary forms, +# with or without modification, +# are permitted provided that +# the following conditions are met: -# * Redistributions of source code must retain the above copyright notice, this +# * Redistributions of source code +# must retain the above copyright notice, this # list of conditions and the following disclaimer. 
-# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation +# * Redistributions in binary form must reproduce +# the above copyright notice, +# this list of conditions and the following +# disclaimer in the documentation # and/or other materials provided with the distribution. -# * Neither the name Facebook nor the names of its contributors may be used to -# endorse or promote products derived from this software without specific +# * Neither the name Facebook nor the names of +# its contributors may be used to +# endorse or promote products derived from +# this software without specific # prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT +# HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, +# BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np -from . import utils as feat_utils +from . 
import utils as feat_utils SMPL_JOINT_NAMES = [ - "root", - "lhip", "rhip", "belly", - "lknee", "rknee", "spine", - "lankle", "rankle", "chest", - "ltoes", "rtoes", "neck", - "linshoulder", "rinshoulder", - "head", "lshoulder", "rshoulder", - "lelbow", "relbow", - "lwrist", "rwrist", - "lhand", "rhand", + 'root', + 'lhip', + 'rhip', + 'belly', + 'lknee', + 'rknee', + 'spine', + 'lankle', + 'rankle', + 'chest', + 'ltoes', + 'rtoes', + 'neck', + 'linshoulder', + 'rinshoulder', + 'head', + 'lshoulder', + 'rshoulder', + 'lelbow', + 'relbow', + 'lwrist', + 'rwrist', + 'lhand', + 'rhand', ] def extract_manual_features(positions): - assert len(positions.shape) == 3 # (seq_len, n_joints, 3) + assert len(positions.shape) == 3 # (seq_len, n_joints, 3) features = [] f = ManualFeatures(positions) for _ in range(1, positions.shape[0]): pose_features = [] pose_features.append( - f.f_nmove("neck", "rhip", "lhip", "rwrist", 1.8 * f.hl) - ) + f.f_nmove('neck', 'rhip', 'lhip', 'rwrist', 1.8 * f.hl)) pose_features.append( - f.f_nmove("neck", "lhip", "rhip", "lwrist", 1.8 * f.hl) - ) + f.f_nmove('neck', 'lhip', 'rhip', 'lwrist', 1.8 * f.hl)) pose_features.append( - f.f_nplane("chest", "neck", "neck", "rwrist", 0.2 * f.hl) - ) + f.f_nplane('chest', 'neck', 'neck', 'rwrist', 0.2 * f.hl)) pose_features.append( - f.f_nplane("chest", "neck", "neck", "lwrist", 0.2 * f.hl) - ) + f.f_nplane('chest', 'neck', 'neck', 'lwrist', 0.2 * f.hl)) pose_features.append( - f.f_move("belly", "chest", "chest", "rwrist", 1.8 * f.hl) - ) + f.f_move('belly', 'chest', 'chest', 'rwrist', 1.8 * f.hl)) pose_features.append( - f.f_move("belly", "chest", "chest", "lwrist", 1.8 * f.hl) - ) + f.f_move('belly', 'chest', 'chest', 'lwrist', 1.8 * f.hl)) pose_features.append( - f.f_angle("relbow", "rshoulder", "relbow", "rwrist", [0, 110]) - ) + f.f_angle('relbow', 'rshoulder', 'relbow', 'rwrist', [0, 110])) pose_features.append( - f.f_angle("lelbow", "lshoulder", "lelbow", "lwrist", [0, 110]) - ) + f.f_angle('lelbow', 'lshoulder', 'lelbow', 'lwrist', [0, 110])) pose_features.append( - f.f_nplane( - "lshoulder", "rshoulder", "lwrist", "rwrist", 2.5 * f.sw - ) - ) + f.f_nplane('lshoulder', 'rshoulder', 'lwrist', 'rwrist', + 2.5 * f.sw)) pose_features.append( - f.f_move("lwrist", "rwrist", "rwrist", "lwrist", 1.4 * f.hl) - ) + f.f_move('lwrist', 'rwrist', 'rwrist', 'lwrist', 1.4 * f.hl)) pose_features.append( - f.f_move("rwrist", "root", "lwrist", "root", 1.4 * f.hl) - ) + f.f_move('rwrist', 'root', 'lwrist', 'root', 1.4 * f.hl)) pose_features.append( - f.f_move("lwrist", "root", "rwrist", "root", 1.4 * f.hl) - ) - pose_features.append(f.f_fast("rwrist", 2.5 * f.hl)) - pose_features.append(f.f_fast("lwrist", 2.5 * f.hl)) + f.f_move('lwrist', 'root', 'rwrist', 'root', 1.4 * f.hl)) + pose_features.append(f.f_fast('rwrist', 2.5 * f.hl)) + pose_features.append(f.f_fast('lwrist', 2.5 * f.hl)) pose_features.append( - f.f_plane("root", "lhip", "ltoes", "rankle", 0.38 * f.hl) - ) + f.f_plane('root', 'lhip', 'ltoes', 'rankle', 0.38 * f.hl)) pose_features.append( - f.f_plane("root", "rhip", "rtoes", "lankle", 0.38 * f.hl) - ) + f.f_plane('root', 'rhip', 'rtoes', 'lankle', 0.38 * f.hl)) pose_features.append( - f.f_nplane("zero", "y_unit", "y_min", "rankle", 1.2 * f.hl) - ) + f.f_nplane('zero', 'y_unit', 'y_min', 'rankle', 1.2 * f.hl)) pose_features.append( - f.f_nplane("zero", "y_unit", "y_min", "lankle", 1.2 * f.hl) - ) + f.f_nplane('zero', 'y_unit', 'y_min', 'lankle', 1.2 * f.hl)) pose_features.append( - f.f_nplane("lhip", "rhip", "lankle", "rankle", 2.1 * 
f.hw) - ) + f.f_nplane('lhip', 'rhip', 'lankle', 'rankle', 2.1 * f.hw)) pose_features.append( - f.f_angle("rknee", "rhip", "rknee", "rankle", [0, 110]) - ) + f.f_angle('rknee', 'rhip', 'rknee', 'rankle', [0, 110])) pose_features.append( - f.f_angle("lknee", "lhip", "lknee", "lankle", [0, 110]) - ) - pose_features.append(f.f_fast("rankle", 2.5 * f.hl)) - pose_features.append(f.f_fast("lankle", 2.5 * f.hl)) + f.f_angle('lknee', 'lhip', 'lknee', 'lankle', [0, 110])) + pose_features.append(f.f_fast('rankle', 2.5 * f.hl)) + pose_features.append(f.f_fast('lankle', 2.5 * f.hl)) pose_features.append( - f.f_angle("neck", "root", "rshoulder", "relbow", [25, 180]) - ) + f.f_angle('neck', 'root', 'rshoulder', 'relbow', [25, 180])) pose_features.append( - f.f_angle("neck", "root", "lshoulder", "lelbow", [25, 180]) - ) + f.f_angle('neck', 'root', 'lshoulder', 'lelbow', [25, 180])) pose_features.append( - f.f_angle("neck", "root", "rhip", "rknee", [50, 180]) - ) + f.f_angle('neck', 'root', 'rhip', 'rknee', [50, 180])) pose_features.append( - f.f_angle("neck", "root", "lhip", "lknee", [50, 180]) - ) + f.f_angle('neck', 'root', 'lhip', 'lknee', [50, 180])) pose_features.append( - f.f_plane("rankle", "neck", "lankle", "root", 0.5 * f.hl) - ) + f.f_plane('rankle', 'neck', 'lankle', 'root', 0.5 * f.hl)) pose_features.append( - f.f_angle("neck", "root", "zero", "y_unit", [70, 110]) - ) + f.f_angle('neck', 'root', 'zero', 'y_unit', [70, 110])) pose_features.append( - f.f_nplane("zero", "minus_y_unit", "y_min", "rwrist", -1.2 * f.hl) - ) + f.f_nplane('zero', 'minus_y_unit', 'y_min', 'rwrist', -1.2 * f.hl)) pose_features.append( - f.f_nplane("zero", "minus_y_unit", "y_min", "lwrist", -1.2 * f.hl) - ) - pose_features.append(f.f_fast("root", 2.3 * f.hl)) + f.f_nplane('zero', 'minus_y_unit', 'y_min', 'lwrist', -1.2 * f.hl)) + pose_features.append(f.f_fast('root', 2.3 * f.hl)) features.append(pose_features) f.next_frame() features = np.array(features, dtype=np.float32).mean(axis=0) @@ -148,6 +150,7 @@ def extract_manual_features(positions): class ManualFeatures: + def __init__(self, positions, joint_names=SMPL_JOINT_NAMES): self.positions = positions self.joint_names = joint_names @@ -155,68 +158,73 @@ def __init__(self, positions, joint_names=SMPL_JOINT_NAMES): # humerus length self.hl = feat_utils.distance_between_points( - [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder", - [4.54445392e-01, 2.21158922e-01, -4.10167128e-02], # "lelbow" + [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder", + [4.54445392e-01, 2.21158922e-01, -4.10167128e-02], # "lelbow" ) # shoulder width self.sw = feat_utils.distance_between_points( - [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder" - [-1.91692337e-01, 2.36928746e-01, -1.23055102e-02,], # "rshoulder" + [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder" + [ + -1.91692337e-01, + 2.36928746e-01, + -1.23055102e-02, + ], # "rshoulder" ) # hip width self.hw = feat_utils.distance_between_points( - [5.64076714e-02, -3.23069185e-01, 1.09197125e-02], # "lhip" - [-6.24834076e-02, -3.31302464e-01, 1.50412619e-02], # "rhip" + [5.64076714e-02, -3.23069185e-01, 1.09197125e-02], # "lhip" + [-6.24834076e-02, -3.31302464e-01, 1.50412619e-02], # "rhip" ) def next_frame(self): self.frame_num += 1 def transform_and_fetch_position(self, j): - if j == "y_unit": + if j == 'y_unit': return [0, 1, 0] - elif j == "minus_y_unit": + elif j == 'minus_y_unit': return [0, -1, 0] - elif j == "zero": + elif j == 'zero': return [0, 0, 0] - elif j == "y_min": 
+ elif j == 'y_min': return [ 0, - min( - [y for (_, y, _) in self.positions[self.frame_num]] - ), + min([y for (_, y, _) in self.positions[self.frame_num]]), 0, ] - return self.positions[self.frame_num][ - self.joint_names.index(j) - ] + return self.positions[self.frame_num][self.joint_names.index(j)] def transform_and_fetch_prev_position(self, j): - return self.positions[self.frame_num - 1][ - self.joint_names.index(j) - ] + return self.positions[self.frame_num - 1][self.joint_names.index(j)] def f_move(self, j1, j2, j3, j4, range): j1_prev, j2_prev, j3_prev, j4_prev = [ - self.transform_and_fetch_prev_position(j) for j in [j1, j2, j3, j4] + self.transform_and_fetch_prev_position(j) + for j in [j1, j2, j3, j4] ] j1, j2, j3, j4 = [ self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] ] return feat_utils.velocity_direction_above_threshold( - j1, j1_prev, j2, j2_prev, j3, j3_prev, range, + j1, + j1_prev, + j2, + j2_prev, + j3, + j3_prev, + range, ) def f_nmove(self, j1, j2, j3, j4, range): j1_prev, j2_prev, j3_prev, j4_prev = [ - self.transform_and_fetch_prev_position(j) for j in [j1, j2, j3, j4] + self.transform_and_fetch_prev_position(j) + for j in [j1, j2, j3, j4] ] j1, j2, j3, j4 = [ self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] ] return feat_utils.velocity_direction_above_threshold_normal( - j1, j1_prev, j2, j3, j4, j4_prev, range - ) + j1, j1_prev, j2, j3, j4, j4_prev, range) def f_plane(self, j1, j2, j3, j4, threshold): j1, j2, j3, j4 = [ @@ -224,7 +232,7 @@ def f_plane(self, j1, j2, j3, j4, threshold): ] return feat_utils.distance_from_plane(j1, j2, j3, j4, threshold) - # + # def f_nplane(self, j1, j2, j3, j4, threshold): j1, j2, j3, j4 = [ self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] @@ -238,7 +246,7 @@ def f_angle(self, j1, j2, j3, j4, range): ] return feat_utils.angle_within_range(j1, j2, j3, j4, range) - # non-relative + # non-relative def f_fast(self, j1, threshold): j1_prev = self.transform_and_fetch_prev_position(j1) j1 = self.transform_and_fetch_position(j1) diff --git a/tools/utils/dance_features/utils.py b/tools/utils/dance_features/utils.py index e4369a8..ab43512 100644 --- a/tools/utils/dance_features/utils.py +++ b/tools/utils/dance_features/utils.py @@ -4,29 +4,44 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# Redistribution and use in source and binary forms, with or without modification, +# Redistribution and use in source and binary forms, +# with or without modification, # are permitted provided that the following conditions are met: -# * Redistributions of source code must retain the above copyright notice, this +# * Redistributions of source code must retain +# the above copyright notice, this # list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation +# * Redistributions in binary form must reproduce +# the above copyright notice, +# this list of conditions and the following +# disclaimer in the documentation # and/or other materials provided with the distribution. -# * Neither the name Facebook nor the names of its contributors may be used to -# endorse or promote products derived from this software without specific +# * Neither the name Facebook nor the names of +# its contributors may be used to +# endorse or promote products derived from +# this software without specific # prior written permission. 
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT +# HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, +# BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np @@ -63,34 +78,40 @@ def angle_within_range(j1, j2, k1, k2, range): return False -def velocity_direction_above_threshold( - j1, j1_prev, j2, j2_prev, p, p_prev, threshold, time_per_frame=1 / 120.0 -): +def velocity_direction_above_threshold(j1, + j1_prev, + j2, + j2_prev, + p, + p_prev, + threshold, + time_per_frame=1 / 120.0): velocity = ( - np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev)) - ) + np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev))) direction = np.array(j2) - np.array(j1) - velocity_along_direction = np.dot(velocity, direction) / np.linalg.norm( - direction - ) + velocity_along_direction = np.dot(velocity, + direction) / np.linalg.norm(direction) velocity_along_direction = velocity_along_direction / time_per_frame return velocity_along_direction > threshold -def velocity_direction_above_threshold_normal( - j1, j1_prev, j2, j3, p, p_prev, threshold, time_per_frame=1 / 120.0 -): +def velocity_direction_above_threshold_normal(j1, + j1_prev, + j2, + j3, + p, + p_prev, + threshold, + time_per_frame=1 / 120.0): velocity = ( - np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev)) - ) + np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev))) j31 = np.array(j3) - np.array(j1) j21 = np.array(j2) - np.array(j1) direction = np.cross(j31, j21) - velocity_along_direction = np.dot(velocity, direction) / np.linalg.norm( - direction - ) + velocity_along_direction = np.dot(velocity, + direction) / np.linalg.norm(direction) velocity_along_direction = velocity_along_direction / time_per_frame return velocity_along_direction > threshold @@ -107,77 +128,71 @@ def calc_average_velocity(positions, i, joint_idx, sliding_window, frame_time): if i + j - 1 < 0 or i + j >= len(positions): continue average_velocity += ( - positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] - ) + positions[i + j][joint_idx] - positions[i + j - 1][joint_idx]) current_window += 1 return np.linalg.norm(average_velocity / 
(current_window * frame_time)) -def calc_average_acceleration( - positions, i, joint_idx, sliding_window, frame_time -): +def calc_average_acceleration(positions, i, joint_idx, sliding_window, + frame_time): current_window = 0 average_acceleration = np.zeros(len(positions[0][joint_idx])) for j in range(-sliding_window, sliding_window + 1): if i + j - 1 < 0 or i + j + 1 >= len(positions): continue - v2 = ( - positions[i + j + 1][joint_idx] - positions[i + j][joint_idx] - ) / frame_time + v2 = (positions[i + j + 1][joint_idx] - + positions[i + j][joint_idx]) / frame_time v1 = ( - positions[i + j][joint_idx] - - positions[i + j - 1][joint_idx] / frame_time - ) + positions[i + j][joint_idx] - + positions[i + j - 1][joint_idx] / frame_time) average_acceleration += (v2 - v1) / frame_time current_window += 1 return np.linalg.norm(average_acceleration / current_window) -def calc_average_velocity_horizontal( - positions, i, joint_idx, sliding_window, frame_time, up_vec="z" -): +def calc_average_velocity_horizontal(positions, + i, + joint_idx, + sliding_window, + frame_time, + up_vec='z'): current_window = 0 average_velocity = np.zeros(len(positions[0][joint_idx])) for j in range(-sliding_window, sliding_window + 1): if i + j - 1 < 0 or i + j >= len(positions): continue average_velocity += ( - positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] - ) + positions[i + j][joint_idx] - positions[i + j - 1][joint_idx]) current_window += 1 - if up_vec == "y": - average_velocity = np.array( - [average_velocity[0], average_velocity[2]] - ) / (current_window * frame_time) - elif up_vec == "z": - average_velocity = np.array( - [average_velocity[0], average_velocity[1]] - ) / (current_window * frame_time) + if up_vec == 'y': + average_velocity = np.array([average_velocity[0], average_velocity[2] + ]) / ( + current_window * frame_time) + elif up_vec == 'z': + average_velocity = np.array([average_velocity[0], average_velocity[1] + ]) / ( + current_window * frame_time) else: raise NotImplementedError return np.linalg.norm(average_velocity) -def calc_average_velocity_vertical( - positions, i, joint_idx, sliding_window, frame_time, up_vec -): +def calc_average_velocity_vertical(positions, i, joint_idx, sliding_window, + frame_time, up_vec): current_window = 0 average_velocity = np.zeros(len(positions[0][joint_idx])) for j in range(-sliding_window, sliding_window + 1): if i + j - 1 < 0 or i + j >= len(positions): continue average_velocity += ( - positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] - ) + positions[i + j][joint_idx] - positions[i + j - 1][joint_idx]) current_window += 1 - if up_vec == "y": + if up_vec == 'y': average_velocity = np.array([average_velocity[1]]) / ( - current_window * frame_time - ) - elif up_vec == "z": + current_window * frame_time) + elif up_vec == 'z': average_velocity = np.array([average_velocity[2]]) / ( - current_window * frame_time - ) + current_window * frame_time) else: raise NotImplementedError - return np.linalg.norm(average_velocity) \ No newline at end of file + return np.linalg.norm(average_velocity) diff --git a/tools/visualize_dance_from_pkl.py b/tools/visualize_dance_from_pkl.py index d91ce0c..e2e7745 100644 --- a/tools/visualize_dance_from_pkl.py +++ b/tools/visualize_dance_from_pkl.py @@ -1,78 +1,69 @@ -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this open-source project. 
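For reference, a minimal restatement of the sliding-window speed computed by calc_average_velocity in tools/utils/dance_features/utils.py above; the helper name average_joint_speed is illustrative and not part of the codebase.

import numpy as np


def average_joint_speed(positions, frame_idx, joint_idx, window, frame_time):
    # positions[f][j] is the 3D location of joint j at frame f.
    # Accumulate the frame-to-frame displacement of one joint over a
    # +/- `window` neighbourhood, then turn it into a speed (norm / time).
    displacement = np.zeros(3)
    frames_used = 0
    for offset in range(-window, window + 1):
        f = frame_idx + offset
        if f - 1 < 0 or f >= len(positions):
            continue
        displacement += positions[f][joint_idx] - positions[f - 1][joint_idx]
        frames_used += 1
    # assumes at least one valid frame pair falls inside the window
    return np.linalg.norm(displacement / (frames_used * frame_time))

Here frame_time plays the same role as the time_per_frame=1 / 120.0 default used by the threshold-based functions above.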
- - -from email.mime import audio -import os -import json -import argparse -import numpy as np - -import numpy - -from tqdm import tqdm - -import os -import mmcv - -import numpy as np -from xrmogen.data_structure.keypoints import Keypoints -from xrmogen.core.visualization import visualize_keypoints3d - -def visualizeAndWritefromPKL(pkl_root, audio_path=None): - - video_root= os.path.join(pkl_root, 'video') - if not os.path.exists(video_root): - os.mkdir(video_root) - - music_names = sorted(os.listdir(audio_path)) - - for pkl_name in tqdm(os.listdir(pkl_root), desc='Generating Videos'): - - if not pkl_name.endswith('.pkl'): - continue - print(pkl_name, flush=True) - result = mmcv.load(os.path.join(pkl_root, pkl_name)) - - np_dance = result[0] - print(np_dance.shape) - - kps3d_arr = np_dance.reshape([np_dance.shape[0], 24, 3]) - - kps3d_arr_w_conf = np.concatenate( - (kps3d_arr, np.ones_like(kps3d_arr[..., 0:1])), - axis=-1 - ) - kps3d_arr_w_conf = np.expand_dims(kps3d_arr_w_conf, axis=1) - # mask array in shape (n_frame, n_person, n_kps) - kps3d_mask = np.ones_like(kps3d_arr_w_conf[..., 0]) - convention = 'smpl' - keypoints3d = Keypoints( - kps=kps3d_arr_w_conf, - mask=kps3d_mask, - convention=convention - ) - visualize_keypoints3d( - keypoints=keypoints3d, - output_path=os.path.join(video_root, pkl_name.split('.pkl')[0] + '.mp4') - ) - dance_name = pkl_name.split('.pkl')[0] - - # video + audio - name = dance_name.split(".")[0] - if 'cAll' in name: - music_name = name[-9:-5] + '.wav' - - if music_name in music_names: - audio_dir_ = os.path.join(audio_path, music_name) - name_w_audio = name + "_audio" - cmd_audio = f"ffmpeg -i {video_root}/{name}.mp4 -i {audio_dir_} -map 0:v -map 1:a -c:v copy -shortest -y {video_root}/{name_w_audio}.mp4 -loglevel quiet" - os.system(cmd_audio) - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='visulize from recorded pkl') - parser.add_argument('--pkl_root', type=str) - parser.add_argument('--audio_path', type=str, default='') - args = parser.parse_args() - - visualizeAndWritefromPKL(args.pkl_root, args.audio_path) +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this open-source project. 
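The rewritten script below still assembles its ffmpeg call as a shell string passed to os.system; an equivalent invocation through subprocess with an argument list (an illustrative alternative, not part of this change; mux_audio is a hypothetical helper) would be:

import subprocess


def mux_audio(video_path, audio_path, out_path):
    # Same flags as the cmd_audio string in visualize_dance_from_pkl.py:
    # take video from the first input and audio from the second, copy the
    # video stream, stop at the shorter stream, overwrite the output.
    subprocess.run([
        'ffmpeg', '-loglevel', 'quiet',
        '-i', video_path, '-i', audio_path,
        '-map', '0:v', '-map', '1:a',
        '-c:v', 'copy', '-shortest', '-y', out_path,
    ], check=True)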
+ +import argparse +import mmcv +import numpy as np +import os +from tqdm import tqdm + +from xrmogen.core.visualization import visualize_keypoints3d +from xrmogen.data_structure.keypoints import Keypoints + + +def visualizeAndWritefromPKL(pkl_root, audio_path=None): + + video_root = os.path.join(pkl_root, 'video') + if not os.path.exists(video_root): + os.mkdir(video_root) + + music_names = sorted(os.listdir(audio_path)) + + for pkl_name in tqdm(os.listdir(pkl_root), desc='Generating Videos'): + + if not pkl_name.endswith('.pkl'): + continue + print(pkl_name, flush=True) + result = mmcv.load(os.path.join(pkl_root, pkl_name)) + + np_dance = result[0] + print(np_dance.shape) + + kps3d_arr = np_dance.reshape([np_dance.shape[0], 24, 3]) + + kps3d_arr_w_conf = np.concatenate( + (kps3d_arr, np.ones_like(kps3d_arr[..., 0:1])), axis=-1) + kps3d_arr_w_conf = np.expand_dims(kps3d_arr_w_conf, axis=1) + # mask array in shape (n_frame, n_person, n_kps) + kps3d_mask = np.ones_like(kps3d_arr_w_conf[..., 0]) + convention = 'smpl' + keypoints3d = Keypoints( + kps=kps3d_arr_w_conf, mask=kps3d_mask, convention=convention) + visualize_keypoints3d( + keypoints=keypoints3d, + output_path=os.path.join(video_root, + pkl_name.split('.pkl')[0] + '.mp4')) + dance_name = pkl_name.split('.pkl')[0] + + # video + audio + name = dance_name.split('.')[0] + if 'cAll' in name: + music_name = name[-9:-5] + '.wav' + + if music_name in music_names: + audio_dir_ = os.path.join(audio_path, music_name) + name_w_audio = name + '_audio' + cmd_audio = f'ffmpeg -i {video_root}/{name}.mp4 -i' +\ + f' {audio_dir_} -map 0:v -map 1:a -c:v' +\ + ' copy -shortest -y' +\ + f' {video_root}/{name_w_audio}.mp4 -loglevel quiet' + os.system(cmd_audio) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='visulize from recorded pkl') + parser.add_argument('--pkl_root', type=str) + parser.add_argument('--audio_path', type=str, default='') + args = parser.parse_args() + + visualizeAndWritefromPKL(args.pkl_root, args.audio_path) diff --git a/xrmogen/core/__init__.py b/xrmogen/core/__init__.py index cbcbd47..e69de29 100644 --- a/xrmogen/core/__init__.py +++ b/xrmogen/core/__init__.py @@ -1,3 +0,0 @@ -from .apis import * -from .hooks import * -from .runner import * diff --git a/xrmogen/core/apis/api.py b/xrmogen/core/apis/api.py index e715f59..fd003b1 100644 --- a/xrmogen/core/apis/api.py +++ b/xrmogen/core/apis/api.py @@ -1,4 +1,5 @@ from mmcv import Config + from .test import test_mogen from .train import train_mogen diff --git a/xrmogen/core/apis/helper.py b/xrmogen/core/apis/helper.py index a87f08f..5a9379d 100644 --- a/xrmogen/core/apis/helper.py +++ b/xrmogen/core/apis/helper.py @@ -1,38 +1,32 @@ import argparse import importlib -import warnings -from functools import partial, reduce - -import torch -from mmcv import Config -from mmcv.parallel import MMDataParallel, MMDistributedDataParallel, collate -from mmcv.runner import (DistSamplerSeedHook, IterBasedRunner, OptimizerHook, - build_optimizer, get_dist_info) +from mmcv.runner import build_optimizer, get_dist_info from torch.utils.data import DataLoader, RandomSampler from xrmogen.datasets import DistributedSampler, build_dataset -__all__ = ['parse_args', 'build_dataloader', 'get_optimizer', 'register_hooks', \ - 'get_runner'] +__all__ = [ + 'parse_args', 'build_dataloader', 'get_optimizer', 'register_hooks', + 'get_runner' +] def parse_args(): parser = argparse.ArgumentParser(description='train a nerf') - parser.add_argument('--config', - help='train config file 
path', - default='configs/nerfs/nerf_base01.py') - parser.add_argument('--dataname', - help='data name in dataset', - default='ficus') - parser.add_argument('--test_only', - help='set to influence on testset once', - action='store_true') + parser.add_argument( + '--config', + help='train config file path', + default='configs/nerfs/nerf_base01.py') + parser.add_argument( + '--dataname', help='data name in dataset', default='ficus') + parser.add_argument( + '--test_only', + help='set to influence on testset once', + action='store_true') args = parser.parse_args() return args - - def build_dataloader(cfg, mode='train'): num_gpus = cfg.num_gpus @@ -48,13 +42,14 @@ def build_dataloader(cfg, mode='train'): bs_per_gpu = loader_cfg['batch_size'] # 分到每个gpu的bs数 bs_all_gpus = bs_per_gpu * num_gpus # 总的bs数 - data_loader = DataLoader(dataset, - batch_size=bs_all_gpus, - sampler=sampler, - num_workers=num_workers, - # collate_fn=partial(collate, - # samples_per_gpu=bs_per_gpu), - shuffle=False) + data_loader = DataLoader( + dataset, + batch_size=bs_all_gpus, + sampler=sampler, + num_workers=num_workers, + # collate_fn=partial(collate, + # samples_per_gpu=bs_per_gpu), + shuffle=False) return data_loader, dataset @@ -65,6 +60,7 @@ def get_optimizer(model, cfg): def register_hooks(hook_cfgs, **variables): + def get_variates(hook_cfg): variates = {} if 'variables' in hook_cfg: diff --git a/xrmogen/core/apis/test.py b/xrmogen/core/apis/test.py index 03ea26e..bdfc4ca 100644 --- a/xrmogen/core/apis/test.py +++ b/xrmogen/core/apis/test.py @@ -1,12 +1,9 @@ -import warnings - import torch -from mmcv.parallel import MMDataParallel, MMDistributedDataParallel, collate -from mmcv.runner import EpochBasedRunner, get_dist_info, init_dist +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import init_dist from xrmogen.models.builder import build_dance_models from xrmogen.utils import get_root_logger - from .helper import build_dataloader, get_runner, register_hooks @@ -16,8 +13,10 @@ def test_mogen(cfg): Args: cfg (dict): The config dict for test, the same config as train. 
the difference between test and val is: - in test phase, use 'EpochBasedRunner' to influence all testset, in one iter - in val phase, use 'IterBasedRunner' to influence 1/N testset, in one epoch (several iters) + in test phase, use 'EpochBasedRunner' + to influence all testset, in one iter + in val phase, use 'IterBasedRunner' + to influence 1/N testset, in one epoch (several iters) """ cfg.workflow = [('val', 1)] # only run val_step one epoch @@ -40,10 +39,11 @@ def test_mogen(cfg): network = MMDataParallel(network.cuda(), device_ids=[0]) Runner = get_runner(cfg.test_runner) - runner = Runner(network, - work_dir=cfg.work_dir, - logger=get_root_logger(log_level=cfg.log_level), - meta=None) + runner = Runner( + network, + work_dir=cfg.work_dir, + logger=get_root_logger(log_level=cfg.log_level), + meta=None) runner.timestamp = cfg.get('timestamp', None) register_hooks(cfg.test_hooks, **locals()) diff --git a/xrmogen/core/apis/train.py b/xrmogen/core/apis/train.py index 824f247..782f08e 100644 --- a/xrmogen/core/apis/train.py +++ b/xrmogen/core/apis/train.py @@ -1,13 +1,10 @@ import os -import warnings - import torch -from mmcv.parallel import MMDataParallel, MMDistributedDataParallel, collate -from mmcv.runner import EpochBasedRunner, get_dist_info, init_dist +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import init_dist from xrmogen.models.builder import build_dance_models from xrmogen.utils import get_root_logger - from .helper import build_dataloader, get_optimizer, get_runner, register_hooks @@ -20,10 +17,9 @@ def train_mogen(cfg): 1. build dataloader 2. optimizer 3. build model - 4. build runnder + 4. build runnder """ - train_loader, trainset = build_dataloader(cfg, mode='train') val_loader, valset = build_dataloader(cfg, mode='val') dataloaders = [train_loader, val_loader] @@ -45,15 +41,15 @@ def train_mogen(cfg): network = MMDataParallel(network.cuda(), device_ids=[0]) Runner = get_runner(cfg.train_runner) - runner = Runner(network, - optimizer=optimizer, - work_dir=cfg.work_dir, - logger=get_root_logger(log_level=cfg.log_level), - meta=None) + runner = Runner( + network, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=get_root_logger(log_level=cfg.log_level), + meta=None) runner.timestamp = cfg.get('timestamp', None) - # register hooks print('register hooks...', flush=True) diff --git a/xrmogen/core/hooks/__init__.py b/xrmogen/core/hooks/__init__.py index 0fbfba3..85bdbb8 100644 --- a/xrmogen/core/hooks/__init__.py +++ b/xrmogen/core/hooks/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
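As context for the loader wiring above, build_dataloader in xrmogen/core/apis/helper.py scales the per-GPU batch size by the GPU count and leaves ordering to the sampler; a condensed sketch (make_loader is an illustrative name, and the real function also builds the dataset and a DistributedSampler from the config):

from torch.utils.data import DataLoader, RandomSampler


def make_loader(dataset, bs_per_gpu, num_gpus, num_workers):
    # effective batch size = per-GPU batch size * number of GPUs,
    # mirroring bs_all_gpus in build_dataloader; shuffling is left to
    # the sampler, so shuffle stays False
    return DataLoader(
        dataset,
        batch_size=bs_per_gpu * num_gpus,
        sampler=RandomSampler(dataset),
        num_workers=num_workers,
        shuffle=False,
    )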
from .test_hooks import SaveTestDancePKLHook -from .validation_hooks import SaveDancePKLHook, SetValPipelineHook from .train_hooks import PassEpochNumberToModelHook +from .validation_hooks import SaveDancePKLHook, SetValPipelineHook __all__ = [ - 'SaveDancePKLHook', 'SetValPipelineHook', 'PassEpochNumberToModelHook', 'SaveTestDancePKLHook' + 'SaveDancePKLHook', 'SetValPipelineHook', 'PassEpochNumberToModelHook', + 'SaveTestDancePKLHook' ] diff --git a/xrmogen/core/hooks/test_hooks.py b/xrmogen/core/hooks/test_hooks.py index 7d4d7d7..2321792 100644 --- a/xrmogen/core/hooks/test_hooks.py +++ b/xrmogen/core/hooks/test_hooks.py @@ -3,15 +3,12 @@ # @Last Modified by: zcy # @Last Modified time: 2022-06-15 17:02:42 -import json -import os - -import imageio +import mmcv import numpy as np -import torch +import os from mmcv.runner import get_dist_info from mmcv.runner.hooks import HOOKS, Hook -import mmcv + @HOOKS.register_module() class SaveTestDancePKLHook(Hook): @@ -20,18 +17,15 @@ def __init__(self, save_folder='validation'): self.save_folder = save_folder self.count = 0 - def before_val_epoch(self, runner): - """ - prepare experiment folder - experiments - """ + """prepare experiment folder experiments.""" self.dance_results = {} def after_val_iter(self, runner): rank, _ = get_dist_info() if rank == 0: - dance_poses, dance_name = runner.outputs['output_pose'], runner.outputs['file_name'] + dance_poses, dance_name = runner.outputs[ + 'output_pose'], runner.outputs['file_name'] self.dance_results[dance_name] = dance_poses def after_val_epoch(self, runner): @@ -41,15 +35,16 @@ def after_val_epoch(self, runner): print(len(self.dance_results), flush=True) - - store_dir = os.path.join(runner.work_dir, self.save_folder, 'epoch' + str(cur_epoch)) + store_dir = os.path.join(runner.work_dir, self.save_folder, + 'epoch' + str(cur_epoch)) os.makedirs(store_dir, exist_ok=True) for key in self.dance_results: np_dance = self.dance_results[key].cpu().data.numpy()[0] root = np_dance[:, :3] np_dance = np_dance + np.tile(root, (1, 24)) np_dance[:, :3] = root - mmcv.dump(np_dance[None], os.path.join(store_dir, key + '.pkl')) + mmcv.dump(np_dance[None], os.path.join(store_dir, + key + '.pkl')) - # need to manually add 1 here + # need to manually add 1 here runner._epoch += 1 diff --git a/xrmogen/core/hooks/train_hooks.py b/xrmogen/core/hooks/train_hooks.py index a461744..260440d 100644 --- a/xrmogen/core/hooks/train_hooks.py +++ b/xrmogen/core/hooks/train_hooks.py @@ -1,11 +1,3 @@ - -import json -import os - -import imageio -import numpy as np -import torch -from mmcv.runner import get_dist_info from mmcv.runner.hooks import HOOKS, Hook @@ -15,18 +7,11 @@ class PassEpochNumberToModelHook(Hook): ndown: multiscales for mipnerf, set to 0 for others """ - def __init__(self, - ): - pass + def __init__(self, ): + + pass def before_train_epoch(self, runner): - """ - prepare experiment folder - experiments - """ + """prepare experiment folder experiments.""" runner.model.module._epoch = runner.epoch - - - - diff --git a/xrmogen/core/hooks/validation_hooks.py b/xrmogen/core/hooks/validation_hooks.py index db7a8b8..a0910db 100644 --- a/xrmogen/core/hooks/validation_hooks.py +++ b/xrmogen/core/hooks/validation_hooks.py @@ -3,14 +3,11 @@ # @Last Modified by: zcy # @Last Modified time: 2022-06-15 17:03:30 -import os - -import imageio +import mmcv import numpy as np -import torch +import os from mmcv.runner import get_dist_info -from mmcv.runner.hooks import HOOKS, Hook, CheckpointHook -import mmcv +from mmcv.runner.hooks 
import HOOKS, Hook @HOOKS.register_module() @@ -20,18 +17,15 @@ def __init__(self, save_folder='validation'): self.save_folder = save_folder self.count = 0 - def before_val_epoch(self, runner): - """ - prepare experiment folder - experiments - """ + """prepare experiment folder experiments.""" self.dance_results = {} def after_val_iter(self, runner): rank, _ = get_dist_info() if rank == 0: - dance_poses, dance_name = runner.outputs['output_pose'], runner.outputs['file_name'] + dance_poses, dance_name = runner.outputs[ + 'output_pose'], runner.outputs['file_name'] self.dance_results[dance_name] = dance_poses def after_val_epoch(self, runner): @@ -41,24 +35,27 @@ def after_val_epoch(self, runner): print(len(self.dance_results), flush=True) - store_dir = os.path.join(runner.work_dir, self.save_folder, 'epoch' + str(cur_epoch)) + store_dir = os.path.join(runner.work_dir, self.save_folder, + 'epoch' + str(cur_epoch)) os.makedirs(store_dir, exist_ok=True) for key in self.dance_results: np_dance = self.dance_results[key].cpu().data.numpy()[0] root = np_dance[:, :3] np_dance = np_dance + np.tile(root, (1, 24)) np_dance[:, :3] = root - mmcv.dump(np_dance[None], os.path.join(store_dir, key + '.pkl')) + mmcv.dump(np_dance[None], os.path.join(store_dir, + key + '.pkl')) - # need to manually add 1 here + # need to manually add 1 here @HOOKS.register_module() class SetValPipelineHook(Hook): """pass val dataset's pipeline to network.""" + def __init__(self, valset=None): self.val_pipeline = valset.pipeline def before_run(self, runner): # only run once runner.model.module.set_val_pipeline(self.val_pipeline) - del self.val_pipeline \ No newline at end of file + del self.val_pipeline diff --git a/xrmogen/core/runner/__init__.py b/xrmogen/core/runner/__init__.py index 9014407..522d88d 100644 --- a/xrmogen/core/runner/__init__.py +++ b/xrmogen/core/runner/__init__.py @@ -1,5 +1,4 @@ -from .base import MoGenTrainRunner, MoGenTestRunner -from .dance_runner import DanceTrainRunner, DanceTestRunner +from .dance_runner import DanceTestRunner, DanceTrainRunner __all__ = [ 'DanceTrainRunner', diff --git a/xrmogen/core/runner/dance_runner.py b/xrmogen/core/runner/dance_runner.py index ed48e53..d9a7093 100644 --- a/xrmogen/core/runner/dance_runner.py +++ b/xrmogen/core/runner/dance_runner.py @@ -1,13 +1,9 @@ -import time -import warnings - -import mmcv -import torch from mmcv.runner import EpochBasedRunner class DanceTrainRunner(EpochBasedRunner): pass + class DanceTestRunner(EpochBasedRunner): pass diff --git a/xrmogen/core/visualization/__init__.py b/xrmogen/core/visualization/__init__.py index cc2d064..6c94288 100644 --- a/xrmogen/core/visualization/__init__.py +++ b/xrmogen/core/visualization/__init__.py @@ -1,3 +1,3 @@ -from .visualize_keypoints3d import \ - visualize_keypoints3d +from .visualize_keypoints3d import visualize_keypoints3d + __all__ = ['visualize_keypoints3d'] diff --git a/xrmogen/core/visualization/visualize_keypoints3d.py b/xrmogen/core/visualization/visualize_keypoints3d.py index da60f93..1379a93 100644 --- a/xrmogen/core/visualization/visualize_keypoints3d.py +++ b/xrmogen/core/visualization/visualize_keypoints3d.py @@ -1,9 +1,9 @@ # yapf: disable import numpy as np +from scipy.spatial.transform import Rotation as scipy_Rotation from typing import Union from xrmogen.data_structure.keypoints import Keypoints -from scipy.spatial.transform import Rotation as scipy_Rotation try: from mmhuman3d.core.visualization.visualize_keypoints3d import ( @@ -52,8 +52,7 @@ def visualize_keypoints3d( keypoints_np 
= keypoints.to_numpy() kps3d = keypoints_np.get_keypoints()[..., :3].copy() rotation = scipy_Rotation.from_euler('zxy', [180, 0, 180], degrees=True) - kps3d = rotation.apply( - kps3d.reshape(-1, 3)).reshape( + kps3d = rotation.apply(kps3d.reshape(-1, 3)).reshape( keypoints_np.get_frame_number(), keypoints_np.get_person_number(), keypoints_np.get_keypoints_number(), 3) if keypoints_np.get_person_number() == 1: diff --git a/xrmogen/data_structure/keypoints.py b/xrmogen/data_structure/keypoints.py index 272d31d..df03a9d 100644 --- a/xrmogen/data_structure/keypoints.py +++ b/xrmogen/data_structure/keypoints.py @@ -4,7 +4,6 @@ import torch from typing import Any, Union from xrprimer.utils.log_utils import get_logger - from xrprimer.utils.path_utils import ( Existence, check_path_existence, check_path_suffix, ) diff --git a/xrmogen/datasets/__init__.py b/xrmogen/datasets/__init__.py index bf6b9d1..68485b8 100644 --- a/xrmogen/datasets/__init__.py +++ b/xrmogen/datasets/__init__.py @@ -1,10 +1,5 @@ -from .builder import DATASETS, build_dataset from .aistpp_dataset import AISTppDataset +from .builder import DATASETS, build_dataset from .samplers import DistributedSampler -__all__ = [ - 'AISTppDataset', - 'DATASETS', - 'build_dataset', - 'DistributedSampler' -] +__all__ = ['AISTppDataset', 'DATASETS', 'build_dataset', 'DistributedSampler'] diff --git a/xrmogen/datasets/aistpp_dataset.py b/xrmogen/datasets/aistpp_dataset.py index fe97cef..4d9410a 100644 --- a/xrmogen/datasets/aistpp_dataset.py +++ b/xrmogen/datasets/aistpp_dataset.py @@ -1,15 +1,18 @@ - -from .load_data.load_music_dance_data import load_train_data_aist, load_test_data_aist -""" Define the paired music-dance dataset. """ -import numpy as np -import torch -import torch.utils.data +# yapf: disable from torch.utils.data import Dataset + from .builder import DATASETS +from .load_data.load_music_dance_data import ( + load_test_data_aist, load_train_data_aist, +) from .pipeline.compose import Compose +# yapf: enable + + @DATASETS.register_module() class AISTppDataset(Dataset): + def __init__(self, data_config, pipeline): self.cfg = data_config self.dances = None @@ -18,14 +21,15 @@ def __init__(self, data_config, pipeline): self._init_load() self.pipeline = Compose(pipeline) - def _init_load(self): if self.mode == 'train': musics, dances, fnames = load_train_data_aist(self.cfg) elif self.mode == 'test': musics, dances, fnames = load_test_data_aist(self.cfg) - - print(len(musics), musics[0].shape, len(dances), dances[0].shape, len(fnames)) + + print( + len(musics), musics[0].shape, len(dances), dances[0].shape, + len(fnames)) self.musics = musics self.dances = dances self.fnames = fnames @@ -40,5 +44,3 @@ def __getitem__(self, index): 'file_names': self.fnames[index], } return self.pipeline(data) - - diff --git a/xrmogen/datasets/builder.py b/xrmogen/datasets/builder.py index 38b9395..94acfd2 100644 --- a/xrmogen/datasets/builder.py +++ b/xrmogen/datasets/builder.py @@ -1,9 +1,4 @@ -import numpy as np -import torch -from mmcv.parallel import collate -from mmcv.runner import get_dist_info -from mmcv.utils import Registry, build_from_cfg, digit_version -from torch.utils.data import DataLoader +from mmcv.utils import Registry, build_from_cfg DATASETS = Registry('dataset') PIPELINES = Registry('pipeline') diff --git a/xrmogen/datasets/load_data/load_music_dance_data.py b/xrmogen/datasets/load_data/load_music_dance_data.py index a74ab1f..b8ffdb5 100644 --- a/xrmogen/datasets/load_data/load_music_dance_data.py +++ 
b/xrmogen/datasets/load_data/load_music_dance_data.py @@ -8,17 +8,18 @@ def load_train_data_aist(cfg): interval = cfg.seq_len move = cfg.move rotmat = cfg.rotmat - external_wav= cfg.external_wav if hasattr(cfg, 'external_wav') else None - external_wav_rate = cfg.external_wav_rate if hasattr(cfg, 'external_wav_rate') else None + external_wav = cfg.external_wav if hasattr(cfg, 'external_wav') else None + external_wav_rate = cfg.external_wav_rate if hasattr( + cfg, 'external_wav_rate') else None tot = 0 music_data, dance_data, input_names = [], [], [] - + # traverse all data fnames = sorted(os.listdir(data_dir)) for fname in fnames: - + path = os.path.join(data_dir, fname) with open(path) as f: @@ -26,56 +27,60 @@ def load_train_data_aist(cfg): np_music = np.array(sample_dict['music_array']).astype(np.float32) if external_wav is not None: - wav_path = os.path.join(external_wav, fname.split('_')[-2] + '.json') + wav_path = os.path.join(external_wav, + fname.split('_')[-2] + '.json') with open(wav_path) as ff: sample_dict_wav = json.loads(ff.read()) - np_music = np.array(sample_dict_wav['music_array']).astype(np.float32) - + np_music = np.array(sample_dict_wav['music_array']).astype( + np.float32) + np_dance = np.array(sample_dict['dance_array']).astype(np.float32) if not rotmat: root = np_dance[:, :3] # the root - np_dance = np_dance - np.tile(root, (1, 24)) # Calculate relative offset with respect to root + np_dance = np_dance - np.tile( + root, + (1, 24)) # Calculate relative offset with respect to root np_dance[:, :3] = root - music_sample_rate = external_wav_rate if external_wav is not None else 1 + music_sample_rate = external_wav_rate \ + if external_wav is not None else 1 - if interval is not None: # just sample a piece of music + if interval is not None: # just sample a piece of music seq_len, dim = np_music.shape for i in range(0, seq_len, move): i_sample = i // music_sample_rate interval_sample = interval // music_sample_rate - music_sub_seq = np_music[i_sample: i_sample + interval_sample] - dance_sub_seq = np_dance[i: i + interval] - + music_sub_seq = np_music[i_sample:i_sample + + interval_sample] + dance_sub_seq = np_dance[i:i + interval] - - if len(music_sub_seq) == interval_sample and len(dance_sub_seq) == interval: + if len(music_sub_seq) == interval_sample and len( + dance_sub_seq) == interval: music_sub_seq_pad = music_sub_seq music_data.append(music_sub_seq_pad) dance_data.append(dance_sub_seq) input_names.append(fname.split('.')[0]) tot += 1 - else: music_data.append(np_music) dance_data.append(np_dance) input_names.append(fname.split('.')[0]) - return music_data, dance_data, input_names def load_test_data_aist(cfg): data_dir = cfg.data_dir - move = cfg.move + move = cfg.move rotmat = cfg.rotmat - external_wav= cfg.external_wav if hasattr(cfg, 'external_wav') else None - external_wav_rate = cfg.external_wav_rate if hasattr(cfg, 'external_wav_rate') else None + external_wav = cfg.external_wav if hasattr(cfg, 'external_wav') else None + external_wav_rate = cfg.external_wav_rate if hasattr( + cfg, 'external_wav_rate') else None music_data, dance_data, input_names = [], [], [] fnames = sorted(os.listdir(data_dir)) @@ -86,35 +91,43 @@ def load_test_data_aist(cfg): sample_dict = json.loads(f.read()) np_music = np.array(sample_dict['music_array']).astype(np.float32) - if external_wav is not None: # using music features from external files - wav_path = os.path.join(external_wav, fname.split('_')[-2] + '.json') + if external_wav is not None: + # using music features from external 
files + wav_path = os.path.join(external_wav, + fname.split('_')[-2] + '.json') with open(wav_path) as ff: sample_dict_wav = json.loads(ff.read()) - np_music = np.array(sample_dict_wav['music_array']).astype(np.float32) - + np_music = np.array(sample_dict_wav['music_array']).astype( + np.float32) + if 'dance_array' in sample_dict: - np_dance = np.array(sample_dict['dance_array']).astype(np.float32) + np_dance = np.array(sample_dict['dance_array']).astype( + np.float32) if not rotmat: root = np_dance[:, :3] # the root - np_dance = np_dance - np.tile(root, (1, 24)) # Calculate relative offset with respect to root + np_dance = np_dance - np.tile( + root, + (1, + 24)) # Calculate relative offset with respect to root np_dance[:, :3] = root - for kk in range((len(np_dance) // move + 1) * move - len(np_dance) ): + for kk in range((len(np_dance) // move + 1) * move - + len(np_dance)): np_dance = np.append(np_dance, np_dance[-1:], axis=0) dance_data.append(np_dance) else: np_dance = None dance_data = None - music_move = external_wav_rate if external_wav is not None else move - + music_move = external_wav_rate \ + if external_wav is not None else move + # fully devisable - for kk in range((len(np_music) // music_move + 1) * music_move - len(np_music)): + for kk in range((len(np_music) // music_move + 1) * music_move - + len(np_music)): np_music = np.append(np_music, np_music[-1:], axis=0) music_data.append(np_music) input_names.append(fname.split('.')[0]) return music_data, dance_data, input_names - - diff --git a/xrmogen/datasets/pipeline/compose.py b/xrmogen/datasets/pipeline/compose.py index aeb18e7..1824e16 100644 --- a/xrmogen/datasets/pipeline/compose.py +++ b/xrmogen/datasets/pipeline/compose.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from collections.abc import Sequence - import mmcv import numpy as np import torch +from collections.abc import Sequence from mmcv.utils import build_from_cfg from ..builder import PIPELINES @@ -17,6 +16,7 @@ class Compose: transforms (list[dict | callable]): Either config dicts of transforms or transform objects. """ + def __init__(self, transforms): assert isinstance(transforms, Sequence) self.transforms = [] @@ -80,6 +80,7 @@ class ToTensor: Args: keys (Sequence[str]): Required keys to be converted. """ + def __init__(self, keys, **kwargs): self.keys = keys self.kwargs = kwargs diff --git a/xrmogen/datasets/samplers/distributed_sampler.py b/xrmogen/datasets/samplers/distributed_sampler.py index 3e8b621..c4d27ee 100644 --- a/xrmogen/datasets/samplers/distributed_sampler.py +++ b/xrmogen/datasets/samplers/distributed_sampler.py @@ -9,6 +9,7 @@ class DistributedSampler(_DistributedSampler): In pytorch of lower versions, there is no ``shuffle`` argument. This child class will port one to DistributedSampler. """ + def __init__(self, dataset, num_replicas=None, diff --git a/xrmogen/models/__init__.py b/xrmogen/models/__init__.py index 13041a8..ffbf38b 100644 --- a/xrmogen/models/__init__.py +++ b/xrmogen/models/__init__.py @@ -1,4 +1,5 @@ from .dance_models.bailando.bailando import Bailando # from .fact.fact import FACT from .dance_models.dancerev.model import DanceRevolution + __all__ = ['Bailando', 'DanceRevolution'] diff --git a/xrmogen/models/builder.py b/xrmogen/models/builder.py index 1a94ffa..595d7bc 100644 --- a/xrmogen/models/builder.py +++ b/xrmogen/models/builder.py @@ -1,6 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
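A condensed sketch of the sliding-window pairing performed by load_train_data_aist above; slice_music_dance and music_rate are illustrative names, and the real loader additionally handles external_wav feature files and rotation-matrix data:

def slice_music_dance(np_music, np_dance, interval, move, music_rate=1):
    # Step through the sequence with stride `move` and keep a (music, dance)
    # pair only when both windows are complete. When music features come from
    # an external file at a coarser rate, indices and window length on the
    # music side are divided by `music_rate`.
    pairs = []
    for i in range(0, len(np_dance), move):
        i_music = i // music_rate
        span_music = interval // music_rate
        music_seq = np_music[i_music:i_music + span_music]
        dance_seq = np_dance[i:i + interval]
        if len(music_seq) == span_music and len(dance_seq) == interval:
            pairs.append((music_seq, dance_seq))
    return pairs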
-import warnings - from mmcv.cnn import MODELS as MMCV_MODELS from mmcv.utils import Registry @@ -8,7 +6,7 @@ DANCE_MODELS = MODELS + def build_dance_models(cfg): # print(cfg.keys()) return DANCE_MODELS.build(cfg) - diff --git a/xrmogen/models/dance_models/bailando/bailando.py b/xrmogen/models/dance_models/bailando/bailando.py index fa9bccb..e1f325c 100644 --- a/xrmogen/models/dance_models/bailando/bailando.py +++ b/xrmogen/models/dance_models/bailando/bailando.py @@ -1,15 +1,10 @@ - -import math -import logging - import torch import torch.nn as nn -from torch.nn import functional as F -from .vqvae.sep_vqvae_root import SepVQVAER -from .gpt.cross_cond_gpt import CrossCondGPT - from ...builder import DANCE_MODELS +from .gpt.cross_cond_gpt import CrossCondGPT +from .vqvae.sep_vqvae_root import SepVQVAER + @DANCE_MODELS.register_module() class Bailando(nn.Module): @@ -19,13 +14,13 @@ def __init__(self, model_config): self.bailando_phase = model_config['bailando_phase'] self.vqvae = SepVQVAER(model_config.vqvae) self.gpt = CrossCondGPT(model_config.gpt) - + # self.val_results = {} - + def train_step(self, data, optimizer, **kwargs): train_phase = self.bailando_phase - music_seq, pose_seq = data['music'], data['dance'] + music_seq, pose_seq = data['music'], data['dance'] optimizer.zero_grad() @@ -49,8 +44,12 @@ def train_step(self, data, optimizer, **kwargs): with torch.no_grad(): quants_pred = self.vqvae.encode(pose_seq) if isinstance(quants_pred, tuple): - quants_input = tuple(quants_pred[ii][0][:, :-1].clone().detach() for ii in range(len(quants_pred))) - quants_target = tuple(quants_pred[ii][0][:, 1:].clone().detach() for ii in range(len(quants_pred))) + quants_input = tuple( + quants_pred[ii][0][:, :-1].clone().detach() + for ii in range(len(quants_pred))) + quants_target = tuple( + quants_pred[ii][0][:, 1:].clone().detach() + for ii in range(len(quants_pred))) else: quants = quants_pred[0] quants_input = quants[:, :-1].clone().detach() @@ -61,14 +60,8 @@ def train_step(self, data, optimizer, **kwargs): else: raise NotImplementedError - stats = { - 'loss': loss.item() - } - outputs = { - 'loss': loss, - 'log_vars': stats, - 'num_samples': out.size(1) - } + stats = {'loss': loss.item()} + outputs = {'loss': loss, 'log_vars': stats, 'num_samples': out.size(1)} # loss.backward() # optimizer.step() @@ -77,26 +70,26 @@ def train_step(self, data, optimizer, **kwargs): def val_step(self, data, optimizer, **kwargs): return self.test_step(data, optimizer, **kwargs) - def test_step(self, data, optimizer, **kwargs): test_phase = self.bailando_phase music_seq, pose_seq = data['music'], data['dance'] self.eval() - + results = [] pose_seq[:, :, :3] = 0 with torch.no_grad(): if test_phase == 'motion vqvae': - + # print(pose_seq[0, 7, 6], ) pose_seq[:, :, :3] = 0 pose_seq_out, _, _ = self.vqvae(pose_seq, test_phase) - results.append(pose_seq_out) + results.append(pose_seq_out) elif test_phase == 'global velocity': - pose_seq[:, :-1, :3] = pose_seq[:, 1:, :3] - pose_seq[:, :-1, :3] + pose_seq[:, :-1, :3] = pose_seq[:, + 1:, :3] - pose_seq[:, :-1, :3] pose_seq[:, -1, :3] = pose_seq[:, -2, :3] pose_seq = pose_seq.clone().detach() @@ -105,15 +98,19 @@ def test_step(self, data, optimizer, **kwargs): global_vel = pose_seq_out[:, :, :3].clone() pose_seq_out[:, 0, :3] = 0 for iii in range(1, pose_seq_out.size(1)): - pose_seq_out[:, iii, :3] = pose_seq_out[:, iii-1, :3] + global_vel[:, iii-1, :] - results.append(pose_seq_out) + pose_seq_out[:, iii, : + 3] = pose_seq_out[:, iii - + 1, :3] + global_vel[:, iii - + 1, 
:] + results.append(pose_seq_out) elif test_phase == 'gpt': pose_seq[:, :, :3] = 0 quants = self.vqvae.encode(pose_seq) if isinstance(quants, tuple): - x = tuple(quants[i][0][:, :1].clone() for i in range(len(quants))) + x = tuple(quants[i][0][:, :1].clone() + for i in range(len(quants))) else: x = quants[0][:, :1].clone() @@ -124,13 +121,15 @@ def test_step(self, data, optimizer, **kwargs): global_vel = pose_sample[:, :, :3].clone() pose_sample[:, 0, :3] = 0 for iii in range(1, pose_sample.size(1)): - pose_sample[:, iii, :3] = pose_sample[:, iii-1, :3] + global_vel[:, iii-1, :] + pose_sample[:, + iii, :3] = pose_sample[:, iii - 1, : + 3] + global_vel[:, iii - + 1, :] results.append(pose_sample) else: raise NotImplementedError - # self.val_results.update({data['file_names'][0]: results[0]}) outputs = { 'output_pose': results[0], @@ -138,7 +137,3 @@ def test_step(self, data, optimizer, **kwargs): } return outputs - - - - diff --git a/xrmogen/models/dance_models/bailando/gpt/cross_cond_gpt.py b/xrmogen/models/dance_models/bailando/gpt/cross_cond_gpt.py index c788fa5..7ddee2c 100644 --- a/xrmogen/models/dance_models/bailando/gpt/cross_cond_gpt.py +++ b/xrmogen/models/dance_models/bailando/gpt/cross_cond_gpt.py @@ -1,23 +1,25 @@ -""" -GPT model: -- the initial stem consists of a combination of token encoding and a positional encoding -- the meat of it is a uniform sequence of Transformer blocks - - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block - - all blocks feed into a central residual pathway similar to resnets -- the final decoder is a linear projection into a vanilla Softmax classifier +"""GPT model: + +- the initial stem consists of a +combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of +Transformer blocks + - each Transformer is a sequential combination of a + 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway + similar to resnets +- the final decoder is a linear projection into a +vanilla Softmax classifier """ import math -import logging - import torch import torch.nn as nn from torch.nn import functional as F - class CrossCondGPT(nn.Module): - """ the full GPT language model, with a context size of block_size """ + """the full GPT language model, with a context size of block_size.""" def __init__(self, config): super().__init__() @@ -36,10 +38,17 @@ def sample(self, xs, cond, shift=None): block_shift = block_size x_up, x_down = xs for k in range(cond.size(1)): - x_cond_up = x_up if x_up.size(1) <= block_size else x_up[:, -(block_shift+(k-block_size-1)%(block_size-block_shift+1)):] - x_cond_down = x_down if x_down.size(1) <= block_size else x_down[:, -(block_shift+(k-block_size-1)%(block_size-block_shift+1)):] # crop context if needed - - cond_input = cond[:, :k+1] if k < block_size else cond[:, k-(block_shift+(k-block_size-1)%(block_size-block_shift+1))+1:k+1] + x_cond_up = x_up if x_up.size(1) <= block_size else x_up[:, -( + block_shift + (k - block_size - 1) % + (block_size - block_shift + 1)):] + x_cond_down = x_down if x_down.size( + 1) <= block_size else x_down[:, -( + block_shift + (k - block_size - 1) % + (block_size - block_shift + 1)):] # crop context if needed + + cond_input = cond[:, :k + 1] if k < block_size else cond[:, k - ( + block_shift + (k - block_size - 1) % + (block_size - block_shift + 1)) + 1:k + 1] logits, _ = self.forward((x_cond_up, x_cond_down), cond_input) @@ -61,15 +70,16 @@ def sample(self, xs, cond, 
shift=None): def forward(self, idxs, cond, targets=None): idx_up, idx_down = idxs - + targets_up, targets_down = None, None if targets is not None: targets_up, targets_down = targets - + feat = self.gpt_base(idx_up, idx_down, cond) - logits_up, logits_down, loss_up, loss_down = self.gpt_head(feat, targets) + logits_up, logits_down, loss_up, loss_down = self.gpt_head( + feat, targets) # logits_down, loss_down = self.down_half_gpt(feat, targets_down) - + if loss_up is not None and loss_down is not None: loss = loss_up + loss_down else: @@ -77,11 +87,14 @@ def forward(self, idxs, cond, targets=None): return (logits_up, logits_down), loss + class CausalCrossConditionalSelfAttention(nn.Module): - """ - A vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. + """A vanilla multi-head masked self-attention layer with a projection at + the end. + + It is possible to use torch.nn.MultiheadAttention here but I am including + an explicit implementation here to show that there is nothing too scary + here. """ def __init__(self, config): @@ -96,33 +109,47 @@ def __init__(self, config): self.resid_drop = nn.Dropout(config.resid_pdrop) # output projection self.proj = nn.Linear(config.n_embd, config.n_embd) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) + # causal mask to ensure that attention is only + # applied to the left in the input sequence + self.register_buffer( + 'mask', + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size)) self.n_head = config.n_head def forward(self, x, layer_past=None): B, T, C = x.size() # T = 3*t (music up down) - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + # calculate query, key, values for all heads + # in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, + C // self.n_head).transpose(1, + 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, + C // self.n_head).transpose(1, + 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, + C // self.n_head).transpose(1, + 2) # (B, nh, T, hs) t = T // 3 - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + # causal self-attention; Self-attend: + # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[:,:,:t,:t].repeat(1, 1, 3, 3) == 0, float('-inf')) + att = att.masked_fill(self.mask[:, :, :t, :t].repeat(1, 1, 3, 3) == 0, + float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view( + B, T, C) # re-assemble all head 
outputs side by side # output projection y = self.resid_drop(self.proj(y)) return y + class Block(nn.Module): - """ an unassuming Transformer block """ + """an unassuming Transformer block.""" def __init__(self, config): super().__init__() @@ -141,27 +168,29 @@ def forward(self, x): x = x + self.mlp(self.ln2(x)) return x + class CrossCondGPTBase(nn.Module): - """ the full GPT language model, with a context size of block_size """ + """the full GPT language model, with a context size of block_size.""" def __init__(self, config): super().__init__() - self.tok_emb_up = nn.Embedding(config.vocab_size_up, config.n_embd ) + self.tok_emb_up = nn.Embedding(config.vocab_size_up, config.n_embd) self.tok_emb_down = nn.Embedding(config.vocab_size_down, config.n_embd) - self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size*3, config.n_embd)) + self.pos_emb = nn.Parameter( + torch.zeros(1, config.block_size * 3, config.n_embd)) self.cond_emb = nn.Linear(config.n_music, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) # transformer - self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) - - self.block_size = config.block_size + self.blocks = nn.Sequential( + *[Block(config) for _ in range(config.n_layer)]) + self.block_size = config.block_size self.apply(self._init_weights) - - # logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + # logger.info("number of parameters: %e", + # sum(p.numel() for p in self.parameters())) def get_block_size(self): return self.block_size @@ -177,65 +206,96 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def configure_optimizers(self, train_config): - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. + """This long function is unfortunately doing something very simple and + is being very defensive: + + We are separating out all parameters of the model into two buckets: + those that will experience weight decay for regularization and those + that won't (biases, and layernorm/embedding weights). We are then + returning the PyTorch optimizer object. 
""" - # separate out all parameters to those that will and won't experience regularizing weight decay + # separate out all parameters to those that + # will and won't experience regularizing weight decay decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name if pn.endswith('bias'): # all biases will not be decayed no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + elif pn.endswith('weight') and isinstance( + m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + elif pn.endswith('weight') and isinstance( + m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) - # special case the position embedding parameter in the root GPT module as not decayed + # special case the position embedding parameter + # in the root GPT module as not decayed no_decay.add('pos_emb') # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert len( + inter_params + ) == 0, 'parameters %s made it into both decay/no_decay sets!' % ( + str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0,\ + f'parameters {str(param_dict.keys() - union_params)} ' + \ + 'were not separated into either decay/no_decay set!' # create the pytorch optimizer object optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + 'params': [param_dict[pn] for pn in sorted(list(decay))], + 'weight_decay': train_config.weight_decay + }, + { + 'params': [param_dict[pn] for pn in sorted(list(no_decay))], + 'weight_decay': 0.0 + }, ] - optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + optimizer = torch.optim.AdamW( + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas) return optimizer def forward(self, idx_up, idx_down, cond): b, t = idx_up.size() - assert t <= self.block_size, "Cannot forward, model block size is exhausted." + assert t <= self.block_size,\ + 'Cannot forward, model block size is exhausted.' b, t = idx_down.size() - assert t <= self.block_size, "Cannot forward, model block size is exhausted." + assert t <= self.block_size,\ + 'Cannot forward, model block size is exhausted.' 
# forward the GPT model # if self.requires_head: - token_embeddings_up = self.tok_emb_up(idx_up) # each index maps to a (learnable) vector - token_embeddings_down = self.tok_emb_down(idx_down) # each index maps to a (learnable) vector - token_embeddings = torch.cat([self.cond_emb(cond), token_embeddings_up, token_embeddings_down ], dim=1) + token_embeddings_up = self.tok_emb_up( + idx_up) # each index maps to a (learnable) vector + token_embeddings_down = self.tok_emb_down( + idx_down) # each index maps to a (learnable) vector + token_embeddings = torch.cat( + [self.cond_emb(cond), token_embeddings_up, token_embeddings_down], + dim=1) + + position_embeddings = torch.cat( + [ + self.pos_emb[:, :t, :], + self.pos_emb[:, self.block_size:self.block_size + t, :], + self.pos_emb[:, self.block_size * 2:self.block_size * 2 + t, :] + ], + dim=1) # each position maps to a (learnable) vector - position_embeddings = torch.cat([self.pos_emb[:, :t, :], self.pos_emb[:, self.block_size:self.block_size+t, :], self.pos_emb[:, self.block_size*2:self.block_size*2+t, :]], dim=1) # each position maps to a (learnable) vector - x = self.drop(token_embeddings + position_embeddings) x = self.blocks(x) @@ -243,19 +303,23 @@ def forward(self, idx_up, idx_down, cond): return x + class CrossCondGPTHead(nn.Module): - """ the full GPT language model, with a context size of block_size """ + """the full GPT language model, with a context size of block_size.""" def __init__(self, config): super().__init__() - self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + self.blocks = nn.Sequential( + *[Block(config) for _ in range(config.n_layer)]) # decoder head self.ln_f = nn.LayerNorm(config.n_embd) self.block_size = config.block_size - self.head_up = nn.Linear(config.n_embd, config.vocab_size_up, bias=False) - self.head_down = nn.Linear(config.n_embd, config.vocab_size_down, bias=False) + self.head_up = nn.Linear( + config.n_embd, config.vocab_size_up, bias=False) + self.head_down = nn.Linear( + config.n_embd, config.vocab_size_down, bias=False) self.apply(self._init_weights) @@ -272,49 +336,70 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def configure_optimizers(self, train_config): - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. + """This long function is unfortunately doing something very simple and + is being very defensive: + + We are separating out all parameters of the model into two buckets: + those that will experience weight decay for regularization and those + that won't (biases, and layernorm/embedding weights). We are then + returning the PyTorch optimizer object. 
""" - # separate out all parameters to those that will and won't experience regularizing weight decay + # separate out all parameters to those that + # will and won't experience regularizing weight decay decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name if pn.endswith('bias'): # all biases will not be decayed no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed + elif pn.endswith('weight') and isinstance( + m, whitelist_weight_modules): + # weights of whitelist modules will + # be weight decayed decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed + elif pn.endswith('weight') and isinstance( + m, blacklist_weight_modules): + # weights of blacklist modules will + # NOT be weight decayed no_decay.add(fpn) - # special case the position embedding parameter in the root GPT module as not decayed + # special case the position embedding parameter + # in the root GPT module as not decayed no_decay.add('pos_emb') # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert len(inter_params) == 0,\ + f'parameters {str(inter_params)} made it into' +\ + ' both decay/no_decay sets!' + assert len(param_dict.keys() - union_params) == 0,\ + f'parameters {str(param_dict.keys() - union_params)}' +\ + ' were not separated' +\ + ' into either decay/no_decay set!' 
# create the pytorch optimizer object optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + 'params': [param_dict[pn] for pn in sorted(list(decay))], + 'weight_decay': train_config.weight_decay + }, + { + 'params': [param_dict[pn] for pn in sorted(list(no_decay))], + 'weight_decay': 0.0 + }, ] - optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + optimizer = torch.optim.AdamW( + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas) return optimizer def forward(self, x, targets=None): @@ -322,9 +407,9 @@ def forward(self, x, targets=None): x = self.blocks(x) x = self.ln_f(x) N, T, C = x.size() - t = T//3 - logits_up = self.head_up(x[:, t:t*2, :]) - logits_down = self.head_down(x[:, t*2:t*3, :]) # down half + t = T // 3 + logits_up = self.head_up(x[:, t:t * 2, :]) + logits_down = self.head_down(x[:, t * 2:t * 3, :]) # down half # if we are given some desired targets also calculate the loss loss_up, loss_down = None, None @@ -332,11 +417,10 @@ def forward(self, x, targets=None): if targets is not None: targets_up, targets_down = targets - loss_up = F.cross_entropy(logits_up.view(-1, logits_up.size(-1)), targets_up.view(-1)) - loss_down = F.cross_entropy(logits_down.view(-1, logits_down.size(-1)), targets_down.view(-1)) - + loss_up = F.cross_entropy( + logits_up.view(-1, logits_up.size(-1)), targets_up.view(-1)) + loss_down = F.cross_entropy( + logits_down.view(-1, logits_down.size(-1)), + targets_down.view(-1)) return logits_up, logits_down, loss_up, loss_down - - - diff --git a/xrmogen/models/dance_models/bailando/vqvae/bottleneck.py b/xrmogen/models/dance_models/bailando/vqvae/bottleneck.py index 95bb1f2..d5064dd 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/bottleneck.py +++ b/xrmogen/models/dance_models/bailando/vqvae/bottleneck.py @@ -5,6 +5,7 @@ class BottleneckBlock(nn.Module): + def __init__(self, k_bins, emb_width, mu): super().__init__() self.k_bins = k_bins @@ -29,7 +30,7 @@ def _tile(self, x): return x def init_k(self, x): - mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins + _, emb_width, k_bins = self.mu, self.emb_width, self.k_bins self.init = True # init k_w using random vectors from x y = self._tile(x) @@ -41,7 +42,7 @@ def init_k(self, x): self.k_elem = t.ones(k_bins, device=self.k.device) def restore_k(self, num_tokens=None, threshold=1.0): - mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins + _, emb_width, k_bins = self.mu, self.emb_width, self.k_bins self.init = True assert self.k.shape == (k_bins, emb_width) self.k_sum = self.k.clone() @@ -56,7 +57,8 @@ def update_k(self, x, x_l): mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins with t.no_grad(): # Calculate new centres - x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device) # k_bins, N * L + x_l_onehot = t.zeros( + k_bins, x.shape[0], device=x.device) # k_bins, N * L x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1) _k_sum = t.matmul(x_l_onehot, x) # k_bins, w @@ -64,23 +66,23 @@ def update_k(self, x, x_l): y = self._tile(x) _k_rand = y[t.randperm(y.shape[0])][:k_bins] - # Update centres old_k = self.k self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum # w, k_bins self.k_elem = mu * self.k_elem + (1. 
- mu) * _k_elem # k_bins usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float() - self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \ - + (1 - usage) * _k_rand - _k_prob = _k_elem / t.sum(_k_elem) # x_l_onehot.mean(dim=-1) # prob of each bin - entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8)) # entropy ie how diverse + self.k = usage * ( + self.k_sum.view(k_bins, emb_width) / + self.k_elem.view(k_bins, 1)) + \ + (1 - usage) * _k_rand + _k_prob = _k_elem / t.sum( + _k_elem) # x_l_onehot.mean(dim=-1) # prob of each bin + entropy = -t.sum( + _k_prob * t.log(_k_prob + 1e-8)) # entropy ie how diverse used_curr = (_k_elem >= self.threshold).sum() usage = t.sum(usage) dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape)) - return dict(entropy=entropy, - used_curr=used_curr, - usage=usage, - dk=dk) + return dict(entropy=entropy, used_curr=used_curr, usage=usage, dk=dk) def preprocess(self, x): # NCT -> NTC -> [NT, C] @@ -90,13 +92,17 @@ def preprocess(self, x): if x.shape[-1] == self.emb_width: prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape)) elif x.shape[-1] == 2 * self.emb_width: - x1, x2 = x[...,:self.emb_width], x[...,self.emb_width:] - prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape))) + x1, x2 = x[..., :self.emb_width], x[..., self.emb_width:] + prenorm = ( + t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape))) # Normalise x = x1 + x2 else: - assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}" + assert False, \ + f'Expected {x.shape[-1]} to be' +\ + f' (1 or 2) * {self.emb_width}' return x, prenorm def postprocess(self, x_l, x_d, x_shape): @@ -109,8 +115,9 @@ def postprocess(self, x_l, x_d, x_shape): def quantise(self, x): # Calculate latent code x_l k_w = self.k.t() - distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) + t.sum(k_w ** 2, dim=0, - keepdim=True) # (N * L, b) + distance = t.sum( + x**2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) + t.sum( + k_w**2, dim=0, keepdim=True) # (N * L, b) min_distance, x_l = t.min(distance, dim=-1) fit = t.mean(min_distance) return x_l, fit @@ -164,35 +171,43 @@ def forward(self, x, update_k=True): update_metrics = {} # Loss - commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape) + commit_loss = t.norm(x_d.detach() - x)**2 / np.prod(x.shape) # Passthrough x_d = x + (x_d - x).detach() # Postprocess - x_l, x_d = self.postprocess(x_l, x_d, (N,T)) - return x_l, x_d, commit_loss, dict(fit=fit, - pn=prenorm, - **update_metrics) + x_l, x_d = self.postprocess(x_l, x_d, (N, T)) + return x_l, x_d, commit_loss, dict( + fit=fit, pn=prenorm, **update_metrics) class Bottleneck(nn.Module): + def __init__(self, l_bins, emb_width, mu, levels): super().__init__() self.levels = levels - level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu) + level_block = lambda level: BottleneckBlock( # noqa: E731 + l_bins, emb_width, mu) self.level_blocks = nn.ModuleList() for level in range(self.levels): self.level_blocks.append(level_block(level)) def encode(self, xs): - zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)] + zs = [ + level_block.encode(x) + for (level_block, x) in zip(self.level_blocks, xs) + ] return zs def decode(self, zs, start_level=0, end_level=None): if end_level is None: end_level = self.levels - xs_quantised = [level_block.decode(z) for (level_block, z) in 
zip(self.level_blocks[start_level:end_level], zs)] + xs_quantised = [ + level_block.decode(z) + for (level_block, + z) in zip(self.level_blocks[start_level:end_level], zs) + ] return xs_quantised def forward(self, xs): @@ -200,7 +215,8 @@ def forward(self, xs): for level in range(self.levels): level_block = self.level_blocks[level] x = xs[level] - z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training) + z, x_quantised, commit_loss, metric = level_block( + x, update_k=self.training) zs.append(z) if not self.training: # Be extra paranoid and make sure the encoder weights can't @@ -212,11 +228,15 @@ def forward(self, xs): metrics.append(metric) return zs, xs_quantised, commit_losses, metrics + class NoBottleneckBlock(nn.Module): + def restore_k(self): pass + class NoBottleneck(nn.Module): + def __init__(self, levels): super().__init__() self.level_blocks = nn.ModuleList() @@ -235,6 +255,8 @@ def decode(self, zs, start_level=0, end_level=None): def forward(self, xs): zero = t.zeros(()).cuda() commit_losses = [zero for _ in range(self.levels)] - metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)] + metrics = [ + dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) + for _ in range(self.levels) + ] return xs, xs, commit_losses, metrics - diff --git a/xrmogen/models/dance_models/bailando/vqvae/encdec.py b/xrmogen/models/dance_models/bailando/vqvae/encdec.py index 5aa90ab..8b1cb91 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/encdec.py +++ b/xrmogen/models/dance_models/bailando/vqvae/encdec.py @@ -1,15 +1,25 @@ -import torch as t import torch.nn as nn -from .resnet import Resnet, Resnet1D + +from .resnet import Resnet1D def assert_shape(x, exp_shape): - assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}" + assert x.shape == exp_shape, f'Expected {exp_shape} got {x.shape}' + class EncoderConvBlock(nn.Module): - def __init__(self, input_emb_width, output_emb_width, down_t, - stride_t, width, depth, m_conv, - dilation_growth_rate=1, dilation_cycle=None, zero_out=False, + + def __init__(self, + input_emb_width, + output_emb_width, + down_t, + stride_t, + width, + depth, + m_conv, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, res_scale=False): super().__init__() blocks = [] @@ -17,41 +27,73 @@ def __init__(self, input_emb_width, output_emb_width, down_t, if down_t > 0: for i in range(down_t): block = nn.Sequential( - # nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t, padding_mode='replicate'), - nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t), - Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale), + # nn.Conv1d(input_emb_width if i == 0 + # else width, width, filter_t, stride_t, + # pad_t, padding_mode='replicate'), + nn.Conv1d(input_emb_width if i == 0 else width, width, + filter_t, stride_t, pad_t), + Resnet1D(width, depth, m_conv, dilation_growth_rate, + dilation_cycle, zero_out, res_scale), ) blocks.append(block) block = nn.Conv1d(width, output_emb_width, 3, 1, 1) - # block = nn.Conv1d(width, output_emb_width, 3, 1, 1, padding_mode='replicate') + # block = nn.Conv1d(width, output_emb_width, 3, 1, 1, + # padding_mode='replicate') blocks.append(block) self.model = nn.Sequential(*blocks) def forward(self, x): return self.model(x) + class DecoderConvBock(nn.Module): - def __init__(self, input_emb_width, output_emb_width, down_t, - stride_t, width, depth, m_conv, 
dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False, reverse_decoder_dilation=False, checkpoint_res=False): + + def __init__(self, + input_emb_width, + output_emb_width, + down_t, + stride_t, + width, + depth, + m_conv, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, + res_scale=False, + reverse_decoder_dilation=False, + checkpoint_res=False): super().__init__() blocks = [] if down_t > 0: filter_t, pad_t = stride_t * 2, stride_t // 2 block = nn.Conv1d(output_emb_width, width, 3, 1, 1) - # block = nn.Conv1d(output_emb_width, width, 3, 1, 1, padding_mode='replicate') + # block = nn.Conv1d(output_emb_width, width, 3, 1, 1, + # padding_mode='replicate') blocks.append(block) for i in range(down_t): block = nn.Sequential( - Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out, res_scale=res_scale, reverse_dilation=reverse_decoder_dilation, checkpoint_res=checkpoint_res), - nn.ConvTranspose1d(width, input_emb_width if i == (down_t - 1) else width, filter_t, stride_t, pad_t) - ) + Resnet1D( + width, + depth, + m_conv, + dilation_growth_rate, + dilation_cycle, + zero_out=zero_out, + res_scale=res_scale, + reverse_dilation=reverse_decoder_dilation, + checkpoint_res=checkpoint_res), + nn.ConvTranspose1d( + width, input_emb_width if i == (down_t - 1) else width, + filter_t, stride_t, pad_t)) blocks.append(block) self.model = nn.Sequential(*blocks) def forward(self, x): return self.model(x) + class Encoder(nn.Module): + def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs): super().__init__() @@ -64,10 +106,13 @@ def __init__(self, input_emb_width, output_emb_width, levels, downs_t, block_kwargs_copy = dict(**block_kwargs) if 'reverse_decoder_dilation' in block_kwargs_copy: del block_kwargs_copy['reverse_decoder_dilation'] - level_block = lambda level, down_t, stride_t: EncoderConvBlock(input_emb_width if level == 0 else output_emb_width, - output_emb_width, - down_t, stride_t, - **block_kwargs_copy) + level_block = lambda \ + level, down_t, stride_t: EncoderConvBlock( # noqa: E731 + input_emb_width if level == 0 else output_emb_width, + output_emb_width, + down_t, + stride_t, + **block_kwargs_copy) self.level_blocks = nn.ModuleList() iterator = zip(list(range(self.levels)), downs_t, strides_t) for level, down_t, stride_t in iterator: @@ -84,13 +129,15 @@ def forward(self, x): for level, down_t, stride_t in iterator: level_block = self.level_blocks[level] x = level_block(x) - emb, T = self.output_emb_width, T // (stride_t ** down_t) + emb, T = self.output_emb_width, T // (stride_t**down_t) assert_shape(x, (N, emb, T)) xs.append(x) return xs + class Decoder(nn.Module): + def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs): super().__init__() @@ -102,17 +149,22 @@ def __init__(self, input_emb_width, output_emb_width, levels, downs_t, self.strides_t = strides_t - level_block = lambda level, down_t, stride_t: DecoderConvBock(output_emb_width, - output_emb_width, - down_t, stride_t, - **block_kwargs) + level_block = lambda level, down_t, stride_t: \ + DecoderConvBock( # noqa: E731 + output_emb_width, + output_emb_width, + down_t, + stride_t, + **block_kwargs) self.level_blocks = nn.ModuleList() iterator = zip(list(range(self.levels)), downs_t, strides_t) for level, down_t, stride_t in iterator: self.level_blocks.append(level_block(level, down_t, stride_t)) self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1) - # self.out = 
nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1, padding_mode='replicate') + # self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1, + # padding_mode='replicate') + def forward(self, xs, all_levels=True): if all_levels: assert len(xs) == self.levels @@ -124,11 +176,12 @@ def forward(self, xs, all_levels=True): assert_shape(x, (N, emb, T)) # 32, 64 ... - iterator = reversed(list(zip(list(range(self.levels)), self.downs_t, self.strides_t))) + iterator = reversed( + list(zip(list(range(self.levels)), self.downs_t, self.strides_t))) for level, down_t, stride_t in iterator: level_block = self.level_blocks[level] x = level_block(x) - emb, T = self.output_emb_width, T * (stride_t ** down_t) + emb, T = self.output_emb_width, T * (stride_t**down_t) assert_shape(x, (N, emb, T)) if level != 0 and all_levels: x = x + xs[level - 1] diff --git a/xrmogen/models/dance_models/bailando/vqvae/resnet.py b/xrmogen/models/dance_models/bailando/vqvae/resnet.py index e21a361..66c5c04 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/resnet.py +++ b/xrmogen/models/dance_models/bailando/vqvae/resnet.py @@ -3,6 +3,7 @@ class ResConvBlock(nn.Module): + def __init__(self, n_in, n_state): super().__init__() self.model = nn.Sequential( @@ -15,24 +16,41 @@ def __init__(self, n_in, n_state): def forward(self, x): return x + self.model(x) + class Resnet(nn.Module): + def __init__(self, n_in, n_depth, m_conv=1.0): super().__init__() - self.model = nn.Sequential(*[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)]) + self.model = nn.Sequential( + *[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)]) def forward(self, x): return self.model(x) + class ResConv1DBlock(nn.Module): - def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0): + + def __init__(self, + n_in, + n_state, + dilation=1, + zero_out=False, + res_scale=1.0): super().__init__() padding = dilation self.model = nn.Sequential( nn.ReLU(), nn.Conv1d(n_in, n_state, 3, 1, padding, dilation), - # nn.Conv1d(n_in, n_state, 3, 1, padding, dilation, padding_mode='replicate'), + # nn.Conv1d(n_in, n_state, 3, 1, + # padding, dilation, padding_mode='replicate'), nn.ReLU(), - nn.Conv1d(n_state, n_in, 1, 1, 0,), + nn.Conv1d( + n_state, + n_in, + 1, + 1, + 0, + ), # nn.Conv1d(n_state, n_in, 1, 1, 0, padding_mode='replicate'), ) if zero_out: @@ -44,19 +62,36 @@ def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0): def forward(self, x): return x + self.res_scale * self.model(x) + class Resnet1D(nn.Module): - def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False, reverse_dilation=False, checkpoint_res=False): + + def __init__(self, + n_in, + n_depth, + m_conv=1.0, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, + res_scale=False, + reverse_dilation=False, + checkpoint_res=False): super().__init__() + def _get_depth(depth): if dilation_cycle is None: return depth else: return depth % dilation_cycle - blocks = [ResConv1DBlock(n_in, int(m_conv * n_in), - dilation=dilation_growth_rate ** _get_depth(depth), - zero_out=zero_out, - res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth)) - for depth in range(n_depth)] + + blocks = [ + ResConv1DBlock( + n_in, + int(m_conv * n_in), + dilation=dilation_growth_rate**_get_depth(depth), + zero_out=zero_out, + res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth)) + for depth in range(n_depth) + ] if reverse_dilation: blocks = blocks[::-1] 
self.checkpoint_res = checkpoint_res diff --git a/xrmogen/models/dance_models/bailando/vqvae/sep_vqvae_root.py b/xrmogen/models/dance_models/bailando/vqvae/sep_vqvae_root.py index 8cc9b53..adcc093 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/sep_vqvae_root.py +++ b/xrmogen/models/dance_models/bailando/vqvae/sep_vqvae_root.py @@ -1,29 +1,26 @@ -import numpy as np import torch import torch.nn as nn - from .vqvae import VQVAE from .vqvae_root import VQVAER - -smpl_down = [0, 1, 2, 4, 5, 7, 8, 10, 11] +smpl_down = [0, 1, 2, 4, 5, 7, 8, 10, 11] smpl_up = [3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + class SepVQVAER(nn.Module): + def __init__(self, hps): super().__init__() self.hps = hps - - self.chanel_num = hps.joint_channel - self.vqvae_up = VQVAE(hps.up_half, len(smpl_up)*self.chanel_num) - self.vqvae_down = VQVAER(hps.down_half, len(smpl_down)*self.chanel_num) + self.chanel_num = hps.joint_channel + self.vqvae_up = VQVAE(hps.up_half, len(smpl_up) * self.chanel_num) + self.vqvae_down = VQVAER(hps.down_half, + len(smpl_down) * self.chanel_num) def decode(self, zs, start_level=0, end_level=None, bs_chunks=1): - """ - zs are list with two elements: z for up and z for down - """ + """zs are list with two elements: z for up and z for down.""" if isinstance(zs, tuple): zup = zs[0] zdown = zs[1] @@ -34,17 +31,25 @@ def decode(self, zs, start_level=0, end_level=None, bs_chunks=1): xdown = self.vqvae_down.decode(zdown) b, t, cup = xup.size() _, _, cdown = xdown.size() - x = torch.zeros(b, t, (cup+cdown)//self.chanel_num, self.chanel_num).cuda() - x[:, :, smpl_up] = xup.view(b, t, cup//self.chanel_num, self.chanel_num) - x[:, :, smpl_down] = xdown.view(b, t, cdown//self.chanel_num, self.chanel_num) - - return x.view(b, t, -1) + x = torch.zeros(b, t, (cup + cdown) // self.chanel_num, + self.chanel_num).cuda() + x[:, :, smpl_up] = xup.view(b, t, cup // self.chanel_num, + self.chanel_num) + x[:, :, smpl_down] = xdown.view(b, t, cdown // self.chanel_num, + self.chanel_num) + return x.view(b, t, -1) def encode(self, x, start_level=0, end_level=None, bs_chunks=1): b, t, c = x.size() - zup = self.vqvae_up.encode(x.view(b, t, c//self.chanel_num, self.chanel_num)[:, :, smpl_up].view(b, t, -1), start_level, end_level, bs_chunks) - zdown = self.vqvae_down.encode(x.view(b, t, c//self.chanel_num, self.chanel_num)[:, :, smpl_down].view(b, t, -1), start_level, end_level, bs_chunks) + zup = self.vqvae_up.encode( + x.view(b, t, c // self.chanel_num, + self.chanel_num)[:, :, smpl_up].view(b, t, -1), start_level, + end_level, bs_chunks) + zdown = self.vqvae_down.encode( + x.view(b, t, c // self.chanel_num, + self.chanel_num)[:, :, smpl_down].view(b, t, -1), + start_level, end_level, bs_chunks) return (zup, zdown) def sample(self, n_samples): @@ -53,33 +58,36 @@ def sample(self, n_samples): xdown = self.vqvae_up.sample(n_samples) b, t, cup = xup.size() _, _, cdown = xdown.size() - x = torch.zeros(b, t, (cup+cdown)//self.chanel_num, self.chanel_num).cuda() - x[:, :, smpl_up] = xup.view(b, t, cup//self.chanel_num, self.chanel_num) - x[:, :, smpl_down] = xdown.view(b, t, cdown//self.chanel_num, self.chanel_num) + x = torch.zeros(b, t, (cup + cdown) // self.chanel_num, + self.chanel_num).cuda() + x[:, :, smpl_up] = xup.view(b, t, cup // self.chanel_num, + self.chanel_num) + x[:, :, smpl_down] = xdown.view(b, t, cdown // self.chanel_num, + self.chanel_num) return x def forward(self, x, phase='motion vqvae'): b, t, c = x.size() - x = x.view(b, t, c//self.chanel_num, self.chanel_num) + x = x.view(b, t, 
c // self.chanel_num, self.chanel_num) xup = x[:, :, smpl_up, :].view(b, t, -1) xdown = x[:, :, smpl_down, :].view(b, t, -1) - x_out_up, loss_up, metrics_up = self.vqvae_up(xup) - x_out_down , loss_down , metrics_down = self.vqvae_down(xdown, phase) + x_out_down, loss_down, metrics_down = self.vqvae_down(xdown, phase) _, _, cup = x_out_up.size() _, _, cdown = x_out_down.size() - xout = torch.zeros(b, t, (cup+cdown)//self.chanel_num, self.chanel_num).cuda().float() - xout[:, :, smpl_up] = x_out_up.view(b, t, cup//self.chanel_num, self.chanel_num) - xout[:, :, smpl_down] = x_out_down.view(b, t, cdown//self.chanel_num, self.chanel_num) - - + xout = torch.zeros(b, t, (cup + cdown) // self.chanel_num, + self.chanel_num).cuda().float() + xout[:, :, smpl_up] = x_out_up.view(b, t, cup // self.chanel_num, + self.chanel_num) + xout[:, :, smpl_down] = x_out_down.view(b, t, cdown // self.chanel_num, + self.chanel_num) if phase == 'motion vqvae': - return xout.view(b, t, -1), 0.5*(loss_down + loss_up), None + return xout.view(b, t, -1), 0.5 * (loss_down + loss_up), None else: metrics_up['acceleration_loss'] *= 0 - metrics_up['velocity_loss'] *= 0 + metrics_up['velocity_loss'] *= 0 return xout.view(b, t, -1), loss_down, [metrics_up, metrics_down] diff --git a/xrmogen/models/dance_models/bailando/vqvae/vqvae.py b/xrmogen/models/dance_models/bailando/vqvae/vqvae.py index d8fa963..b730815 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/vqvae.py +++ b/xrmogen/models/dance_models/bailando/vqvae/vqvae.py @@ -2,27 +2,30 @@ import torch as t import torch.nn as nn -from .encdec import Encoder, Decoder, assert_shape -from .bottleneck import NoBottleneck, Bottleneck - +from .bottleneck import Bottleneck, NoBottleneck +from .encdec import Decoder, Encoder, assert_shape def dont_update(params): for param in params: param.requires_grad = False + def update(params): for param in params: param.requires_grad = True + def calculate_strides(strides, downs): - return [stride ** down for stride, down in zip(strides, downs)] + return [stride**down for stride, down in zip(strides, downs)] def _loss_fn(x_target, x_pred): - return t.mean(t.abs(x_pred - x_target)) + return t.mean(t.abs(x_pred - x_target)) + class VQVAE(nn.Module): + def __init__(self, hps, input_dim=72): super().__init__() self.hps = hps @@ -36,7 +39,7 @@ def __init__(self, hps, input_dim=72): mu = hps.l_mu commit = hps.commit - multipliers = hps.hvqvae_multipliers + multipliers = hps.hvqvae_multipliers use_bottleneck = hps.use_bottleneck if use_bottleneck: print('We use bottleneck!') @@ -44,10 +47,13 @@ def __init__(self, hps, input_dim=72): print('We do not use bottleneck!') if not hasattr(hps, 'dilation_cycle'): hps.dilation_cycle = None - block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv, \ - dilation_growth_rate=hps.dilation_growth_rate, \ - dilation_cycle=hps.dilation_cycle, \ - reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation) + block_kwargs = dict( + width=hps.width, + depth=hps.depth, + m_conv=hps.m_conv, + dilation_growth_rate=hps.dilation_growth_rate, + dilation_cycle=hps.dilation_cycle, + reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation) self.sample_length = input_shape[0] x_shape, x_channels = input_shape[:-1], input_shape[-1] @@ -55,24 +61,28 @@ def __init__(self, hps, input_dim=72): self.downsamples = calculate_strides(strides_t, downs_t) self.hop_lengths = np.cumprod(self.downsamples) - self.z_shapes = z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)] + 
self.z_shapes = [(x_shape[0] // self.hop_lengths[level], ) + for level in range(levels)] self.levels = levels if multipliers is None: self.multipliers = [1] * levels else: - assert len(multipliers) == levels, "Invalid number of multipliers" + assert len(multipliers) == levels, 'Invalid number of multipliers' self.multipliers = multipliers + def _block_kwargs(level): this_block_kwargs = dict(block_kwargs) - this_block_kwargs["width"] *= self.multipliers[level] - this_block_kwargs["depth"] *= self.multipliers[level] + this_block_kwargs['width'] *= self.multipliers[level] + this_block_kwargs['depth'] *= self.multipliers[level] return this_block_kwargs - encoder = lambda level: Encoder(x_channels, emb_width, level + 1, - downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level)) - decoder = lambda level: Decoder(x_channels, emb_width, level + 1, - downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level)) + encoder = lambda level: Encoder( # noqa: E731 + x_channels, emb_width, level + 1, downs_t[:level + 1], + strides_t[:level + 1], **_block_kwargs(level)) + decoder = lambda level: Decoder( # noqa: E731 + x_channels, emb_width, level + 1, downs_t[:level + 1], + strides_t[:level + 1], **_block_kwargs(level)) self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() for level in range(levels): @@ -91,18 +101,18 @@ def _block_kwargs(level): self.reg = hps.reg if hasattr(hps, 'reg') else 0 self.acc = hps.acc if hasattr(hps, 'acc') else 0 self.vel = hps.vel if hasattr(hps, 'vel') else 0 - if self.reg is 0: + if self.reg == 0: print('No motion regularization!') def preprocess(self, x): # x: NTC [-1,1] -> NCT [-1,1] assert len(x.shape) == 3 - x = x.permute(0,2,1).float() + x = x.permute(0, 2, 1).float() return x def postprocess(self, x): # x: NTC [-1,1] <- NCT [-1,1] - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) return x def _decode(self, zs, start_level=0, end_level=None): @@ -110,7 +120,8 @@ def _decode(self, zs, start_level=0, end_level=None): if end_level is None: end_level = self.levels assert len(zs) == end_level - start_level - xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level) + xs_quantised = self.bottleneck.decode( + zs, start_level=start_level, end_level=end_level) assert len(xs_quantised) == end_level - start_level # Use only lowest level @@ -124,7 +135,8 @@ def decode(self, zs, start_level=0, end_level=None, bs_chunks=1): x_outs = [] for i in range(bs_chunks): zs_i = [z_chunk[i] for z_chunk in z_chunks] - x_out = self._decode(zs_i, start_level=start_level, end_level=end_level) + x_out = self._decode( + zs_i, start_level=start_level, end_level=end_level) x_outs.append(x_out) return t.cat(x_outs, dim=0) @@ -145,20 +157,23 @@ def encode(self, x, start_level=0, end_level=None, bs_chunks=1): x_chunks = t.chunk(x, bs_chunks, dim=0) zs_list = [] for x_i in x_chunks: - zs_i = self._encode(x_i, start_level=start_level, end_level=end_level) + zs_i = self._encode( + x_i, start_level=start_level, end_level=end_level) zs_list.append(zs_i) zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)] return zs def sample(self, n_samples): - zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device='cuda') for z_shape in self.z_shapes] + zs = [ + t.randint( + 0, self.l_bins, size=(n_samples, *z_shape), device='cuda') + for z_shape in self.z_shapes + ] return self.decode(zs) def forward(self, x): metrics = {} - N = x.shape[0] - # Encode/Decode x_in = self.preprocess(x) xs = [] @@ -167,15 +182,15 @@ def forward(self, x): x_out = 
encoder(x_in) xs.append(x_out[-1]) - zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs) + zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck( + xs) x_outs = [] for level in range(self.levels): decoder = self.decoders[level] - x_out = decoder(xs_quantised[level:level+1], all_levels=False) + x_out = decoder(xs_quantised[level:level + 1], all_levels=False) assert_shape(x_out, x_in.shape) x_outs.append(x_out) - recons_loss = t.zeros(()).to(x.device) regularization = t.zeros(()).to(x.device) velocity_loss = t.zeros(()).to(x.device) @@ -191,25 +206,29 @@ def forward(self, x): recons_loss += this_recons_loss - velocity_loss += _loss_fn( x_out[:, 1:] - x_out[:, :-1], x_target[:, 1:] - x_target[:, :-1]) - acceleration_loss += _loss_fn(x_out[:, 2:] + x_out[:, :-2] - 2 * x_out[:, 1:-1], x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]) - + velocity_loss += _loss_fn(x_out[:, 1:] - x_out[:, :-1], + x_target[:, 1:] - x_target[:, :-1]) + acceleration_loss += _loss_fn( + x_out[:, 2:] + x_out[:, :-2] - 2 * x_out[:, 1:-1], + x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]) + # this loss can not be split from the model due to commit_loss commit_loss = sum(commit_losses) - loss = recons_loss + commit_loss * self.commit + self.vel * velocity_loss + self.acc * acceleration_loss + loss = recons_loss + commit_loss * self.commit + \ + self.vel * velocity_loss + \ + self.acc * acceleration_loss with t.no_grad(): l1_loss = _loss_fn(x_target, x_out) - - - metrics.update(dict( - recons_loss=recons_loss, - l1_loss=l1_loss, - commit_loss=commit_loss, - regularization=regularization, - velocity_loss=velocity_loss, - acceleration_loss=acceleration_loss)) + metrics.update( + dict( + recons_loss=recons_loss, + l1_loss=l1_loss, + commit_loss=commit_loss, + regularization=regularization, + velocity_loss=velocity_loss, + acceleration_loss=acceleration_loss)) for key, val in metrics.items(): metrics[key] = val.detach() diff --git a/xrmogen/models/dance_models/bailando/vqvae/vqvae_root.py b/xrmogen/models/dance_models/bailando/vqvae/vqvae_root.py index 946f345..d6754be 100644 --- a/xrmogen/models/dance_models/bailando/vqvae/vqvae_root.py +++ b/xrmogen/models/dance_models/bailando/vqvae/vqvae_root.py @@ -2,25 +2,30 @@ import torch as t import torch.nn as nn -from .encdec import Encoder, Decoder, assert_shape -from .bottleneck import NoBottleneck, Bottleneck +from .bottleneck import Bottleneck, NoBottleneck +from .encdec import Decoder, Encoder, assert_shape def dont_update(params): for param in params: param.requires_grad = False + def update(params): for param in params: param.requires_grad = True + def calculate_strides(strides, downs): - return [stride ** down for stride, down in zip(strides, downs)] + return [stride**down for stride, down in zip(strides, downs)] + def _loss_fn(x_target, x_pred): - return t.mean(t.abs(x_pred - x_target)) + return t.mean(t.abs(x_pred - x_target)) + class VQVAER(nn.Module): + def __init__(self, hps, input_dim=72): super().__init__() self.hps = hps @@ -34,7 +39,7 @@ def __init__(self, hps, input_dim=72): mu = hps.l_mu commit = hps.commit - multipliers = hps.hvqvae_multipliers + multipliers = hps.hvqvae_multipliers use_bottleneck = hps.use_bottleneck if use_bottleneck: print('We use bottleneck!') @@ -42,10 +47,13 @@ def __init__(self, hps, input_dim=72): print('We do not use bottleneck!') if not hasattr(hps, 'dilation_cycle'): hps.dilation_cycle = None - block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv, \ - 
dilation_growth_rate=hps.dilation_growth_rate, \ - dilation_cycle=hps.dilation_cycle, \ - reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation) + block_kwargs = dict( + width=hps.width, + depth=hps.depth, + m_conv=hps.m_conv, + dilation_growth_rate=hps.dilation_growth_rate, + dilation_cycle=hps.dilation_cycle, + reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation) self.sample_length = input_shape[0] x_shape, x_channels = input_shape[:-1], input_shape[-1] @@ -53,26 +61,31 @@ def __init__(self, hps, input_dim=72): self.downsamples = calculate_strides(strides_t, downs_t) self.hop_lengths = np.cumprod(self.downsamples) - self.z_shapes = z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)] + self.z_shapes = [(x_shape[0] // self.hop_lengths[level], ) + for level in range(levels)] self.levels = levels if multipliers is None: self.multipliers = [1] * levels else: - assert len(multipliers) == levels, "Invalid number of multipliers" + assert len(multipliers) == levels, 'Invalid number of multipliers' self.multipliers = multipliers + def _block_kwargs(level): this_block_kwargs = dict(block_kwargs) - this_block_kwargs["width"] *= self.multipliers[level] - this_block_kwargs["depth"] *= self.multipliers[level] + this_block_kwargs['width'] *= self.multipliers[level] + this_block_kwargs['depth'] *= self.multipliers[level] return this_block_kwargs - encoder = lambda level: Encoder(x_channels, emb_width, level + 1, - downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level)) - decoder = lambda level: Decoder(x_channels, emb_width, level + 1, - downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level)) - decoder_root = lambda level: Decoder(hps.joint_channel, emb_width, level + 1, - downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level)) + encoder = lambda level: Encoder( # noqa: E731 + x_channels, emb_width, level + 1, downs_t[:level + 1], + strides_t[:level + 1], **_block_kwargs(level)) + decoder = lambda level: Decoder( # noqa: E731 + x_channels, emb_width, level + 1, downs_t[:level + 1], + strides_t[:level + 1], **_block_kwargs(level)) + decoder_root = lambda level: Decoder( # noqa: E731 + hps.joint_channel, emb_width, level + 1, downs_t[:level + 1], + strides_t[:level + 1], **_block_kwargs(level)) self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() self.decoders_root = nn.ModuleList() @@ -93,18 +106,18 @@ def _block_kwargs(level): self.reg = hps.reg if hasattr(hps, 'reg') else 0 self.acc = hps.acc if hasattr(hps, 'acc') else 0 self.vel = hps.vel if hasattr(hps, 'vel') else 0 - if self.reg is 0: + if self.reg == 0: print('No motion regularization!') def preprocess(self, x): # x: NTC [-1,1] -> NCT [-1,1] assert len(x.shape) == 3 - x = x.permute(0,2,1).float() + x = x.permute(0, 2, 1).float() return x def postprocess(self, x): # x: NTC [-1,1] <- NCT [-1,1] - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) return x def _decode(self, zs, start_level=0, end_level=None): @@ -112,17 +125,19 @@ def _decode(self, zs, start_level=0, end_level=None): if end_level is None: end_level = self.levels assert len(zs) == end_level - start_level - xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level) + xs_quantised = self.bottleneck.decode( + zs, start_level=start_level, end_level=end_level) assert len(xs_quantised) == end_level - start_level # Use only lowest level - decoder, decoder_root, x_quantised = self.decoders[start_level], self.decoders_root[start_level], xs_quantised[0:1] + decoder, decoder_root, x_quantised = 
self.decoders[ + start_level], self.decoders_root[start_level], xs_quantised[0:1] x_out = decoder(x_quantised, all_levels=False) x_vel_out = decoder_root(x_quantised, all_levels=False) x_out = self.postprocess(x_out) x_vel_out = self.postprocess(x_vel_out) - + _, _, cc = x_vel_out.size() x_out[:, :, :cc] = x_vel_out.clone() return x_out @@ -132,7 +147,8 @@ def decode(self, zs, start_level=0, end_level=None, bs_chunks=1): x_outs = [] for i in range(bs_chunks): zs_i = [z_chunk[i] for z_chunk in z_chunks] - x_out = self._decode(zs_i, start_level=start_level, end_level=end_level) + x_out = self._decode( + zs_i, start_level=start_level, end_level=end_level) x_outs.append(x_out) return t.cat(x_outs, dim=0) @@ -154,25 +170,28 @@ def encode(self, x, start_level=0, end_level=None, bs_chunks=1): x_chunks = t.chunk(x, bs_chunks, dim=0) zs_list = [] for x_i in x_chunks: - zs_i = self._encode(x_i, start_level=start_level, end_level=end_level) + zs_i = self._encode( + x_i, start_level=start_level, end_level=end_level) zs_list.append(zs_i) zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)] return zs def sample(self, n_samples): - zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device='cuda') for z_shape in self.z_shapes] + zs = [ + t.randint( + 0, self.l_bins, size=(n_samples, *z_shape), device='cuda') + for z_shape in self.z_shapes + ] return self.decode(zs) def forward(self, x, phase='motion vqvae'): - + if phase == 'global velocity': self.bottleneck.eval() with t.no_grad(): metrics = {} - N = x.shape[0] - x_zero = x.clone() x_zero[:, :, :self.hps.joint_channel] = 0 @@ -186,29 +205,31 @@ def forward(self, x, phase='motion vqvae'): x_out = encoder(x_in) xs.append(x_out[-1]) - zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs) + zs, xs_quantised, commit_losses, _ = \ + self.bottleneck( + xs) x_outs = [] x_outs_vel = [] - + for level in range(self.levels): decoder = self.decoders[level] if phase == 'global velocity': decoder.eval() decoder_root = self.decoders_root[level] - x_out = decoder(xs_quantised[level:level+1], all_levels=False) - x_vel_out = decoder_root(xs_quantised[level:level+1], all_levels=False) + x_out = decoder(xs_quantised[level:level + 1], all_levels=False) + x_vel_out = decoder_root( + xs_quantised[level:level + 1], all_levels=False) assert_shape(x_out, x_in.shape) x_outs.append(x_out) x_outs_vel.append(x_vel_out) - recons_loss = t.zeros(()).to(x.device) - regularization = t.zeros(()).to(x.device) velocity_loss = t.zeros(()).to(x.device) acceleration_loss = t.zeros(()).to(x.device) - x_target = x_zero if phase=='motion vqvae' else x.float()[:, :, :self.hps.joint_channel] + x_target = x_zero if phase == 'motion vqvae' else x.float( + )[:, :, :self.hps.joint_channel] for level in reversed(range(self.levels)): x_out_vel = self.postprocess(x_outs_vel[level]) @@ -218,37 +239,48 @@ def forward(self, x, phase='motion vqvae'): x_out[:, :, :cc] = x_out_vel if phase == 'motion vqvae': - this_recons_loss = _loss_fn(x_target, x_out_zero) - this_velocity_loss = _loss_fn( x_out_zero[:, 1:] - x_out_zero[:, :-1], x_target[:, 1:] - x_target[:, :-1]) - this_acceleration_loss = _loss_fn(x_out_zero[:, 2:] + x_out_zero[:, :-2] - 2 * x_out_zero[:, 1:-1], x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]) + this_recons_loss = _loss_fn(x_target, x_out_zero) + this_velocity_loss = _loss_fn( + x_out_zero[:, 1:] - x_out_zero[:, :-1], + x_target[:, 1:] - x_target[:, :-1]) + this_acceleration_loss = _loss_fn( + x_out_zero[:, 2:] + x_out_zero[:, :-2] - + 2 
* x_out_zero[:, 1:-1], + x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]) else: - this_recons_loss =_loss_fn(x_target, x_out_vel) + this_recons_loss = _loss_fn(x_target, x_out_vel) this_velocity_loss = 0 - this_acceleration_loss = _loss_fn( x_out_vel[:, 1:] - x_out_vel[:, :-1], x_target[:, 1:] - x_target[:, :-1]) + this_acceleration_loss = _loss_fn( + x_out_vel[:, 1:] - x_out_vel[:, :-1], + x_target[:, 1:] - x_target[:, :-1]) metrics[f'recons_loss_l{level + 1}'] = this_recons_loss recons_loss += this_recons_loss velocity_loss += this_velocity_loss - acceleration_loss += this_acceleration_loss - + acceleration_loss += this_acceleration_loss + if phase == 'motion vqvae': - # this loss can not be split from the model due to commit_loss + # this loss can not be split from the model due to commit_loss commit_loss = sum(commit_losses) - loss = recons_loss + commit_loss * self.commit + self.vel * velocity_loss + self.acc * acceleration_loss + loss = recons_loss + \ + commit_loss * self.commit + \ + self.vel * velocity_loss + \ + self.acc * acceleration_loss else: loss = recons_loss + self.acc * acceleration_loss with t.no_grad(): - l1_loss = _loss_fn(x_target, x_out_zero) if phase == 'motion vqvae' else _loss_fn(x_target, x_out_vel) - - - metrics.update(dict( - recons_loss=recons_loss, - l1_loss=l1_loss, - velocity_loss=l1_loss, - acceleration_loss=acceleration_loss - )) + l1_loss = _loss_fn( + x_target, x_out_zero) if phase == 'motion vqvae' else _loss_fn( + x_target, x_out_vel) + + metrics.update( + dict( + recons_loss=recons_loss, + l1_loss=l1_loss, + velocity_loss=l1_loss, + acceleration_loss=acceleration_loss)) for key, val in metrics.items(): metrics[key] = val.detach() diff --git a/xrmogen/models/dance_models/dancerev/layers.py b/xrmogen/models/dance_models/dancerev/layers.py index 4909767..1c4250d 100644 --- a/xrmogen/models/dance_models/dancerev/layers.py +++ b/xrmogen/models/dance_models/dancerev/layers.py @@ -1,108 +1,116 @@ -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this open-source project. - - -""" Define the attention layers. 
""" -import torch -import numpy as np -import torch.nn as nn -import torch.nn.functional as F - - -class ScaledDotProductAttention(nn.Module): - """ Scaled Dot-Product Attention """ - def __init__(self, temperature, attn_dropout=0.1): - super().__init__() - self.temperature = temperature - self.dropout = nn.Dropout(attn_dropout) - self.softmax = nn.Softmax(dim=2) - - def forward(self, q, k, v, mask=None): - attn = torch.bmm(q, k.transpose(1, 2)) - attn = attn / self.temperature - - if mask is not None: - batch_size, _, _ = attn.size() - mask = mask.unsqueeze(0).expand(batch_size, -1, -1) - attn = attn.masked_fill(mask, -np.inf) - - attn = self.softmax(attn) - attn = self.dropout(attn) - output = torch.bmm(attn, v) - - return output, attn - - -class MultiHeadAttention(nn.Module): - """ Multi-Head Attention module """ - - def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): - super().__init__() - - self.n_head = n_head - self.d_k = d_k - self.d_v = d_v - - self.w_qs = nn.Linear(d_model, n_head * d_k) - self.w_ks = nn.Linear(d_model, n_head * d_k) - self.w_vs = nn.Linear(d_model, n_head * d_v) - nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) - nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) - nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) - - self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) - self.layer_norm = nn.LayerNorm(d_model) - - self.fc = nn.Linear(n_head * d_v, d_model) - nn.init.xavier_normal_(self.fc.weight) - - self.dropout = nn.Dropout(dropout) - - def forward(self, q, k, v, mask=None): - - d_k, d_v, n_head = self.d_k, self.d_v, self.n_head - - sz_b, len_q, _ = q.size() - sz_b, len_k, _ = k.size() - sz_b, len_v, _ = v.size() - - residual = q - - q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) - k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) - v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) - - q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk - k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk - v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv - - # mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. - output, attn = self.attention(q, k, v, mask=mask) - - output = output.view(n_head, sz_b, len_q, d_v) - output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) - - output = self.dropout(self.fc(output)) - output = self.layer_norm(output + residual) - - return output, attn - - -class PositionwiseFeedForward(nn.Module): - """ A two-feed-forward-layer module """ - - def __init__(self, d_in, d_hid, dropout=0.1): - super().__init__() - self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise - self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise - self.layer_norm = nn.LayerNorm(d_in) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - residual = x - output = x.transpose(1, 2) - output = self.w_2(F.relu(self.w_1(output))) - output = output.transpose(1, 2) - output = self.dropout(output) - output = self.layer_norm(output + residual) - return output +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this open-source project. 
+"""Define the attention layers.""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention.""" + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, mask=None): + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + batch_size, _, _ = attn.size() + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + attn = self.dropout(attn) + output = torch.bmm(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module.""" + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + nn.init.normal_( + self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_( + self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_( + self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + + self.attention = ScaledDotProductAttention( + temperature=np.power(d_k, 0.5)) + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(n_head * d_v, d_model) + nn.init.xavier_normal_(self.fc.weight) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + + residual = q + + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, + d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, + d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, + d_v) # (n*b) x lv x dv + + # mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
+ output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_q, d_v) + output = output.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_q, + -1) # b x lq x (n*dv) + + output = self.dropout(self.fc(output)) + output = self.layer_norm(output + residual) + + return output, attn + + +class PositionwiseFeedForward(nn.Module): + """A two-feed-forward-layer module.""" + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise + self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise + self.layer_norm = nn.LayerNorm(d_in) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + output = x.transpose(1, 2) + output = self.w_2(F.relu(self.w_1(output))) + output = output.transpose(1, 2) + output = self.dropout(output) + output = self.layer_norm(output + residual) + return output diff --git a/xrmogen/models/dance_models/dancerev/model.py b/xrmogen/models/dance_models/dancerev/model.py index 29c577c..aa9dae3 100644 --- a/xrmogen/models/dance_models/dancerev/model.py +++ b/xrmogen/models/dance_models/dancerev/model.py @@ -1,347 +1,395 @@ -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this open-source project. - - -""" Define the Seq2Seq Generation Network """ -from os import device_encoding -import this -import numpy as np -import torch -import torch.nn as nn - -from .layers import MultiHeadAttention, PositionwiseFeedForward -from ...builder import DANCE_MODELS - - -BOS_POSE_AIST = np.array([ - 0.01340632513165474, 1.6259130239486694, -0.09833218157291412, 0.0707249641418457, 1.5451008081436157, -0.12474726885557175, -0.04773886129260063, 1.536355972290039, -0.11427298933267593, 0.015812935307621956, 1.7525817155838013, -0.12864114344120026, 0.13902147114276886, 1.1639258861541748, -0.0879698246717453, -0.10036090016365051, 1.1553057432174683, -0.08047012239694595, 0.006522613577544689, 1.8904004096984863, -0.10235153883695602, 0.07891514897346497, 0.7553867101669312, -0.20340093970298767, -0.037818294018507004, 0.7545002698898315, -0.1963980495929718, 0.00045378319919109344, 1.9454832077026367, -0.09329807013273239, 0.11616306006908417, 0.668250560760498, -0.0974099189043045, -0.05322670564055443, 0.6652328968048096, -0.07871627062559128, -0.014527007937431335, 2.159270763397217, -0.08067376166582108, 0.0712718814611435, 2.0614874362945557, -0.08859370648860931, -0.08343493938446045, 2.0597264766693115, -0.09117652475833893, -0.002253010869026184, 2.244560718536377, -0.024742677807807922, 0.19795098900794983, 2.098480463027954, -0.09858542680740356, -0.20080527663230896, 2.0911219120025635, -0.0731159895658493, 0.30632710456848145, 1.8656978607177734, -0.09286995232105255, -0.3086402714252472, 1.8520605564117432, -0.06464222073554993, 0.25927090644836426, 1.632638931274414, 0.02665536105632782, -0.2640104591846466, 1.6051883697509766, 0.0331537127494812, 0.2306937426328659, 1.5523173809051514, 0.051218822598457336, -0.24223697185516357, 1.5211939811706543, 0.05606864392757416 -]) - -BOS_POSE_AIST_ROT = np.concatenate([ - np.array([0.05072649, 1.87570345, -0.24885127]), # root - np.concatenate([np.eye(3).reshape(-1)] * 24) -]) - -def get_non_pad_mask(seq): - assert seq.dim() == 3 - non_pad_mask = torch.abs(seq).sum(2).ne(0).type(torch.float) - return non_pad_mask.unsqueeze(-1) - - -def get_attn_key_pad_mask(seq_k, seq_q): - """ For masking out the padding part of key sequence. 
""" - len_q = seq_q.size(1) - padding_mask = torch.abs(seq_k).sum(2).eq(0) # sum the vector of last dim and then judge - padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk - - return padding_mask - - -def get_subsequent_mask(seq, sliding_windown_size): - """ For masking out the subsequent info. """ - batch_size, seq_len, _ = seq.size() - mask = torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8) - - mask = torch.triu(mask, diagonal=-sliding_windown_size) - mask = torch.tril(mask, diagonal=sliding_windown_size) - mask = 1 - mask - # print(mask) - return mask.bool() - - -def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): - """ Sinusoid position encoding table """ - def cal_angle(position, hid_idx): - return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) - - def get_posi_angle_vec(position): - return [cal_angle(position, hid_j) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) - - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - if padding_idx is not None: - # zero vector for padding dimension - sinusoid_table[padding_idx] = 0. - - return torch.FloatTensor(sinusoid_table) - - -class EncoderLayer(nn.Module): - """ Compose with two layers """ - - def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): - super(EncoderLayer, self).__init__() - self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) - self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) - - def forward(self, enc_input, slf_attn_mask=None, non_pad_mask=None): - - enc_output, enc_slf_attn = self.slf_attn( - enc_input, enc_input, enc_input, mask=slf_attn_mask) - # enc_output *= non_pad_mask - - enc_output = self.pos_ffn(enc_output) - # enc_output *= non_pad_mask - - return enc_output, enc_slf_attn - - -class Encoder(nn.Module): - """ A encoder model with self attention mechanism. 
""" - - def __init__( - self, max_seq_len=1800, input_size=20, d_word_vec=10, - n_layers=6, n_head=8, d_k=64, d_v=64, - d_model=10, d_inner=256, dropout=0.1): - - super().__init__() - - self.d_model = d_model - n_position = max_seq_len + 1 - - self.src_emb = nn.Linear(input_size, d_word_vec) - - self.position_enc = nn.Embedding.from_pretrained( - get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), - freeze=True) - - self.layer_stack = nn.ModuleList([ - EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) - for _ in range(n_layers)]) - - def forward(self, src_seq, src_pos, mask=None, return_attns=False): - - enc_slf_attn_list = [] - - # -- Forward - enc_output = self.src_emb(src_seq) + self.position_enc(src_pos).float() - - for enc_layer in self.layer_stack: - enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=mask) - - if return_attns: - enc_slf_attn_list += [enc_slf_attn] - - if return_attns: - return enc_output, enc_slf_attn_list - return enc_output, - - -class Decoder(nn.Module): - def __init__(self, input_size=274, d_word_vec=150, hidden_size=200, - dropout=0.1, encoder_d_model=200, rotmat=False): - super().__init__() - - self.input_size = input_size - self.d_word_vec = d_word_vec - self.hidden_size = hidden_size - - self.tgt_emb = nn.Linear(input_size, d_word_vec) - self.dropout = nn.Dropout(dropout) - self.encoder_d_model = encoder_d_model - - self.lstm1 = nn.LSTMCell(d_word_vec, hidden_size) - self.lstm2 = nn.LSTMCell(hidden_size, hidden_size) - self.lstm3 = nn.LSTMCell(hidden_size, hidden_size) - self.rotmat = rotmat - - def init_state(self, bsz, device): - c0 = torch.randn(bsz, self.hidden_size).to(device) - c1 = torch.randn(bsz, self.hidden_size).to(device) - c2 = torch.randn(bsz, self.hidden_size).to(device) - h0 = torch.randn(bsz, self.hidden_size).to(device) - h1 = torch.randn(bsz, self.hidden_size).to(device) - h2 = torch.randn(bsz, self.hidden_size).to(device) - - vec_h = [h0, h1, h2] - vec_c = [c0, c1, c2] - - if self.rotmat: - bos = BOS_POSE_AIST_ROT - bos = np.tile(bos, (bsz, 1)) - else: - bos = BOS_POSE_AIST - bos = np.tile(bos, (bsz, 1)) - root = bos[:, :3] - bos = bos - np.tile(root, (1, 24)) - bos[:, :3] = root - out_frame = torch.from_numpy(bos).float().to(device) - out_seq = torch.FloatTensor(bsz, 1).to(device) - - return (vec_h, vec_c), out_frame, out_seq - - def forward(self, in_frame, vec_h, vec_c): - - in_frame = self.tgt_emb(in_frame) - in_frame = self.dropout(in_frame) - - vec_h0, vec_c0 = self.lstm1(in_frame, (vec_h[0], vec_c[0])) - vec_h1, vec_c1 = self.lstm2(vec_h[0], (vec_h[1], vec_c[1])) - vec_h2, vec_c2 = self.lstm3(vec_h[1], (vec_h[2], vec_c[2])) - - vec_h_new = [vec_h0, vec_h1, vec_h2] - vec_c_new = [vec_c0, vec_c1, vec_c2] - return vec_h2, vec_h_new, vec_c_new - -@DANCE_MODELS.register_module() -class DanceRevolution(nn.Module): - def __init__(self, model_config): - super().__init__() - args = model_config - self.model_args = args - self._epoch = None - - encoder = Encoder(max_seq_len=args.max_seq_len, - input_size=args.d_frame_vec, - d_word_vec=args.frame_emb_size, - n_layers=args.n_layers, - n_head=args.n_head, - d_k=args.d_k, - d_v=args.d_v, - d_model=args.d_model, - d_inner=args.d_inner, - dropout=args.dropout) - - decoder = Decoder(input_size=args.d_pose_vec, - d_word_vec=args.pose_emb_size, - hidden_size=args.d_inner, - encoder_d_model=args.d_model, - dropout=args.dropout, - rotmat=args.rotmat - ) - - - condition_step=args.condition_step - sliding_windown_size=args.sliding_windown_size - 
lambda_v=args.lambda_v - - - self.encoder = encoder - self.decoder = decoder - self.linear = nn.Linear(decoder.hidden_size + encoder.d_model, decoder.input_size) - - self.condition_step = condition_step - self.sliding_windown_size = sliding_windown_size - self.lambda_v = lambda_v - device = torch.device('cuda' if args.cuda else 'cpu') - self.device = device - - - def init_decoder_hidden(self, bsz): - return self.decoder.init_state(bsz, self.device) - - # dynamic auto-condition + self-attention mask - def forward(self, src_seq, tgt_seq, epoch_i): - bsz, seq_len, _ = tgt_seq.size() - src_pos = (torch.arange(seq_len).long() + 1)[None].expand(bsz, -1).detach().to(self.device) - hidden, dec_output, out_seq = self.init_decoder_hidden(tgt_seq.size(0)) - # forward - - vec_h, vec_c = hidden - - enc_mask = get_subsequent_mask(src_seq, self.sliding_windown_size) - enc_outputs, *_ = self.encoder(src_seq, src_pos, mask=enc_mask) - - groundtruth_mask = torch.ones(seq_len, self.condition_step) - prediction_mask = torch.zeros(seq_len, int(epoch_i * self.lambda_v)) - mask = torch.cat([prediction_mask, groundtruth_mask], 1).view(-1)[:seq_len] # for random - - preds = [] - for i in range(seq_len): - dec_input = tgt_seq[:, i] if mask[i] == 1 else dec_output.detach() # dec_output - dec_output, vec_h, vec_c = self.decoder(dec_input, vec_h, vec_c) - dec_output = torch.cat([dec_output, enc_outputs[:, i]], 1) - dec_output = self.linear(dec_output) - preds.append(dec_output) - - outputs = [z.unsqueeze(1) for z in preds] - outputs = torch.cat(outputs, dim=1) - return outputs - - def generate(self, src_seq,): - """ Generate dance pose in one batch """ - with torch.no_grad(): - # Use the pre-defined begin of pose (BOP) to generate whole sequence - bsz, src_seq_len, _ = src_seq.size() - src_pos = (torch.arange(src_seq_len).long() + 1)[None].expand(bsz, -1).detach().to(self.device) - # bsz, tgt_seq_len, dim = tgt_seq.size() - tgt_seq_len = 1 - generated_frames_num = src_seq_len - tgt_seq_len - - hidden, dec_output, out_seq = self.init_decoder_hidden(bsz) - vec_h, vec_c = hidden - - enc_mask = get_subsequent_mask(src_seq, self.model_args.sliding_windown_size) - enc_outputs, *_ = self.encoder(src_seq, src_pos, enc_mask) - - preds = [] - for i in range(tgt_seq_len): - # dec_input = tgt_seq[:, i] - dec_input = dec_output - dec_output, vec_h, vec_c = self.decoder(dec_input, vec_h, vec_c) - dec_output = torch.cat([dec_output, enc_outputs[:, i]], 1) - dec_output = self.linear(dec_output) - preds.append(dec_output) - - for i in range(generated_frames_num): - dec_input = dec_output - dec_output, vec_h, vec_c = self.decoder(dec_input, vec_h, vec_c) - dec_output = torch.cat([dec_output, enc_outputs[:, i + tgt_seq_len]], 1) - dec_output = self.linear(dec_output) - preds.append(dec_output) - - outputs = [z.unsqueeze(1) for z in preds] - outputs = torch.cat(outputs, dim=1) - return outputs - - def train_step(self, data, optimizer, **kwargs): - self.encoder.train() - self.decoder.train() - self.linear.train() - - aud_seq, pose_seq = data['music'], data['dance'] - - gold_seq = pose_seq[:, 1:] - src_aud = aud_seq[:, :-1] - src_pos = pose_seq[:, :-1] - - optimizer.zero_grad() - - output = self.forward(src_aud, src_pos, self._epoch) - - loss = torch.nn.functional.mse_loss(output, gold_seq) - - stats = { - 'loss': loss.item() - } - - outputs = { - 'loss': loss, - 'log_vars': stats, - 'num_samples': output.size(1) - } - - return outputs - - def val_step(self, data, optimizer, **kwargs): - return self.test_step(data, optimizer, **kwargs) - - 
-    def test_step(self, data, optimizer, **kwargs):
-        results = []
-        self.eval()
-        with torch.no_grad():
-            aud_seq_eval, pose_seq_eval = data['music'], data['dance']
-
-            pose_seq_out = self.generate(aud_seq_eval)
-            results.append(pose_seq_out)
-        outputs = {
-            'output_pose': results[0],
-            'file_name': data['file_names'][0]
-        }
-        return outputs
-
-
-
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this open-source project.
+"""Define the Seq2Seq Generation Network."""
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...builder import DANCE_MODELS
+from .layers import MultiHeadAttention, PositionwiseFeedForward
+
+BOS_POSE_AIST = np.array([
+    0.01340632513165474, 1.6259130239486694, -0.09833218157291412,
+    0.0707249641418457, 1.5451008081436157, -0.12474726885557175,
+    -0.04773886129260063, 1.536355972290039, -0.11427298933267593,
+    0.015812935307621956, 1.7525817155838013, -0.12864114344120026,
+    0.13902147114276886, 1.1639258861541748, -0.0879698246717453,
+    -0.10036090016365051, 1.1553057432174683, -0.08047012239694595,
+    0.006522613577544689, 1.8904004096984863, -0.10235153883695602,
+    0.07891514897346497, 0.7553867101669312, -0.20340093970298767,
+    -0.037818294018507004, 0.7545002698898315, -0.1963980495929718,
+    0.00045378319919109344, 1.9454832077026367, -0.09329807013273239,
+    0.11616306006908417, 0.668250560760498, -0.0974099189043045,
+    -0.05322670564055443, 0.6652328968048096, -0.07871627062559128,
+    -0.014527007937431335, 2.159270763397217, -0.08067376166582108,
+    0.0712718814611435, 2.0614874362945557, -0.08859370648860931,
+    -0.08343493938446045, 2.0597264766693115, -0.09117652475833893,
+    -0.002253010869026184, 2.244560718536377, -0.024742677807807922,
+    0.19795098900794983, 2.098480463027954, -0.09858542680740356,
+    -0.20080527663230896, 2.0911219120025635, -0.0731159895658493,
+    0.30632710456848145, 1.8656978607177734, -0.09286995232105255,
+    -0.3086402714252472, 1.8520605564117432, -0.06464222073554993,
+    0.25927090644836426, 1.632638931274414, 0.02665536105632782,
+    -0.2640104591846466, 1.6051883697509766, 0.0331537127494812,
+    0.2306937426328659, 1.5523173809051514, 0.051218822598457336,
+    -0.24223697185516357, 1.5211939811706543, 0.05606864392757416
+])
+
+BOS_POSE_AIST_ROT = np.concatenate([
+    np.array([0.05072649, 1.87570345, -0.24885127]),  # root
+    np.concatenate([np.eye(3).reshape(-1)] * 24)
+])
+
+
+def get_non_pad_mask(seq):
+    assert seq.dim() == 3
+    non_pad_mask = torch.abs(seq).sum(2).ne(0).type(torch.float)
+    return non_pad_mask.unsqueeze(-1)
+
+
+def get_attn_key_pad_mask(seq_k, seq_q):
+    """For masking out the padding part of key sequence."""
+    len_q = seq_q.size(1)
+    padding_mask = torch.abs(seq_k).sum(2).eq(
+        0)  # sum the vector of last dim and then judge
+    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q,
+                                                    -1)  # b x lq x lk
+
+    return padding_mask
+
+
+def get_subsequent_mask(seq, sliding_windown_size):
+    """For masking out the subsequent info."""
+    batch_size, seq_len, _ = seq.size()
+    mask = torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8)
+
+    mask = torch.triu(mask, diagonal=-sliding_windown_size)
+    mask = torch.tril(mask, diagonal=sliding_windown_size)
+    mask = 1 - mask
+    # print(mask)
+    return mask.bool()
+
+
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+    """Sinusoid position encoding table."""
+
+    def cal_angle(position, hid_idx):
+        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position):
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array(
+        [get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    if padding_idx is not None:
+        # zero vector for padding dimension
+        sinusoid_table[padding_idx] = 0.
+
+    return torch.FloatTensor(sinusoid_table)
+
+
+class EncoderLayer(nn.Module):
+    """Compose with two layers."""
+
+    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
+        super(EncoderLayer, self).__init__()
+        self.slf_attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, dropout=dropout)
+        self.pos_ffn = PositionwiseFeedForward(
+            d_model, d_inner, dropout=dropout)
+
+    def forward(self, enc_input, slf_attn_mask=None, non_pad_mask=None):
+
+        enc_output, enc_slf_attn = self.slf_attn(
+            enc_input, enc_input, enc_input, mask=slf_attn_mask)
+        # enc_output *= non_pad_mask
+
+        enc_output = self.pos_ffn(enc_output)
+        # enc_output *= non_pad_mask
+
+        return enc_output, enc_slf_attn
+
+
+class Encoder(nn.Module):
+    """An encoder model with self attention mechanism."""
+
+    def __init__(self,
+                 max_seq_len=1800,
+                 input_size=20,
+                 d_word_vec=10,
+                 n_layers=6,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
+                 d_model=10,
+                 d_inner=256,
+                 dropout=0.1):
+
+        super().__init__()
+
+        self.d_model = d_model
+        n_position = max_seq_len + 1
+
+        self.src_emb = nn.Linear(input_size, d_word_vec)
+
+        self.position_enc = nn.Embedding.from_pretrained(
+            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
+            freeze=True)
+
+        self.layer_stack = nn.ModuleList([
+            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+            for _ in range(n_layers)
+        ])
+
+    def forward(self, src_seq, src_pos, mask=None, return_attns=False):
+
+        enc_slf_attn_list = []
+
+        # -- Forward
+        enc_output = self.src_emb(src_seq) + self.position_enc(src_pos).float()
+
+        for enc_layer in self.layer_stack:
+            enc_output, enc_slf_attn = enc_layer(
+                enc_output, slf_attn_mask=mask)
+
+            if return_attns:
+                enc_slf_attn_list += [enc_slf_attn]
+
+        if return_attns:
+            return enc_output, enc_slf_attn_list
+        return enc_output,
+
+
+class Decoder(nn.Module):
+
+    def __init__(self,
+                 input_size=274,
+                 d_word_vec=150,
+                 hidden_size=200,
+                 dropout=0.1,
+                 encoder_d_model=200,
+                 rotmat=False):
+        super().__init__()
+
+        self.input_size = input_size
+        self.d_word_vec = d_word_vec
+        self.hidden_size = hidden_size
+
+        self.tgt_emb = nn.Linear(input_size, d_word_vec)
+        self.dropout = nn.Dropout(dropout)
+        self.encoder_d_model = encoder_d_model
+
+        self.lstm1 = nn.LSTMCell(d_word_vec, hidden_size)
+        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
+        self.lstm3 = nn.LSTMCell(hidden_size, hidden_size)
+        self.rotmat = rotmat
+
+    def init_state(self, bsz, device):
+        c0 = torch.randn(bsz, self.hidden_size).to(device)
+        c1 = torch.randn(bsz, self.hidden_size).to(device)
+        c2 = torch.randn(bsz, self.hidden_size).to(device)
+        h0 = torch.randn(bsz, self.hidden_size).to(device)
+        h1 = torch.randn(bsz, self.hidden_size).to(device)
+        h2 = torch.randn(bsz, self.hidden_size).to(device)
+
+        vec_h = [h0, h1, h2]
+        vec_c = [c0, c1, c2]
+
+        if self.rotmat:
+            bos = BOS_POSE_AIST_ROT
+            bos = np.tile(bos, (bsz, 1))
+        else:
+            bos = BOS_POSE_AIST
+            bos = np.tile(bos, (bsz, 1))
+            root = bos[:, :3]
+            bos = bos - np.tile(root, (1, 24))
+            bos[:, :3] = root
+        out_frame = torch.from_numpy(bos).float().to(device)
+        out_seq = torch.FloatTensor(bsz, 1).to(device)
+
+        return (vec_h, vec_c), out_frame, out_seq
+
+    def forward(self, in_frame, vec_h, vec_c):
+
+        in_frame = self.tgt_emb(in_frame)
+        in_frame = self.dropout(in_frame)
+
+        vec_h0, vec_c0 = self.lstm1(in_frame, (vec_h[0], vec_c[0]))
+        vec_h1, vec_c1 = self.lstm2(vec_h[0], (vec_h[1], vec_c[1]))
+        vec_h2, vec_c2 = self.lstm3(vec_h[1], (vec_h[2], vec_c[2]))
+
+        vec_h_new = [vec_h0, vec_h1, vec_h2]
+        vec_c_new = [vec_c0, vec_c1, vec_c2]
+        return vec_h2, vec_h_new, vec_c_new
+
+
+@DANCE_MODELS.register_module()
+class DanceRevolution(nn.Module):
+
+    def __init__(self, model_config):
+        super().__init__()
+        args = model_config
+        self.model_args = args
+        self._epoch = None
+
+        encoder = Encoder(
+            max_seq_len=args.max_seq_len,
+            input_size=args.d_frame_vec,
+            d_word_vec=args.frame_emb_size,
+            n_layers=args.n_layers,
+            n_head=args.n_head,
+            d_k=args.d_k,
+            d_v=args.d_v,
+            d_model=args.d_model,
+            d_inner=args.d_inner,
+            dropout=args.dropout)
+
+        decoder = Decoder(
+            input_size=args.d_pose_vec,
+            d_word_vec=args.pose_emb_size,
+            hidden_size=args.d_inner,
+            encoder_d_model=args.d_model,
+            dropout=args.dropout,
+            rotmat=args.rotmat)
+
+        condition_step = args.condition_step
+        sliding_windown_size = args.sliding_windown_size
+        lambda_v = args.lambda_v
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.linear = nn.Linear(decoder.hidden_size + encoder.d_model,
+                                decoder.input_size)
+
+        self.condition_step = condition_step
+        self.sliding_windown_size = sliding_windown_size
+        self.lambda_v = lambda_v
+        device = torch.device('cuda' if args.cuda else 'cpu')
+        self.device = device
+
+    def init_decoder_hidden(self, bsz):
+        return self.decoder.init_state(bsz, self.device)
+
+    # dynamic auto-condition + self-attention mask
+    def forward(self, src_seq, tgt_seq, epoch_i):
+        bsz, seq_len, _ = tgt_seq.size()
+        src_pos = (torch.arange(seq_len).long() + 1)[None].expand(
+            bsz, -1).detach().to(self.device)
+        hidden, dec_output, out_seq = self.init_decoder_hidden(tgt_seq.size(0))
+        # forward
+
+        vec_h, vec_c = hidden
+
+        enc_mask = get_subsequent_mask(src_seq, self.sliding_windown_size)
+        enc_outputs, *_ = self.encoder(src_seq, src_pos, mask=enc_mask)
+
+        groundtruth_mask = torch.ones(seq_len, self.condition_step)
+        prediction_mask = torch.zeros(seq_len, int(epoch_i * self.lambda_v))
+        mask = torch.cat([prediction_mask, groundtruth_mask],
+                         1).view(-1)[:seq_len]  # for random
+
+        preds = []
+        for i in range(seq_len):
+            dec_input = tgt_seq[:, i] if mask[i] == 1 else dec_output.detach(
+            )  # dec_output
+            dec_output, vec_h, vec_c = self.decoder(dec_input, vec_h, vec_c)
+            dec_output = torch.cat([dec_output, enc_outputs[:, i]], 1)
+            dec_output = self.linear(dec_output)
+            preds.append(dec_output)
+
+        outputs = [z.unsqueeze(1) for z in preds]
+        outputs = torch.cat(outputs, dim=1)
+        return outputs
+
+    def generate(
+        self,
+        src_seq,
+    ):
+        """Generate dance pose in one batch."""
+        with torch.no_grad():
+            # Use the pre-defined begin of pose (BOP)
+            # to generate whole sequence
+            bsz, src_seq_len, _ = src_seq.size()
+            src_pos = (torch.arange(src_seq_len).long() + 1)[None].expand(
+                bsz, -1).detach().to(self.device)
+            # bsz, tgt_seq_len, dim = tgt_seq.size()
+            tgt_seq_len = 1
+            generated_frames_num = src_seq_len - tgt_seq_len
+
+            hidden, dec_output, out_seq = self.init_decoder_hidden(bsz)
+            vec_h, vec_c = hidden
+
+            enc_mask = get_subsequent_mask(
+                src_seq, self.model_args.sliding_windown_size)
+            enc_outputs, *_ = self.encoder(src_seq, src_pos, enc_mask)
+
+            preds = []
+            for i in range(tgt_seq_len):
+                # dec_input = tgt_seq[:, i]
+                dec_input = dec_output
+                dec_output, vec_h, vec_c = self.decoder(
+                    dec_input, vec_h, vec_c)
+                dec_output = torch.cat([dec_output, enc_outputs[:, i]], 1)
+                dec_output = self.linear(dec_output)
+                preds.append(dec_output)
+
+            for i in range(generated_frames_num):
+                dec_input = dec_output
+                dec_output, vec_h, vec_c = self.decoder(
+                    dec_input, vec_h, vec_c)
+                dec_output = torch.cat(
+                    [dec_output, enc_outputs[:, i + tgt_seq_len]], 1)
+                dec_output = self.linear(dec_output)
+                preds.append(dec_output)
+
+        outputs = [z.unsqueeze(1) for z in preds]
+        outputs = torch.cat(outputs, dim=1)
+        return outputs
+
+    def train_step(self, data, optimizer, **kwargs):
+        self.encoder.train()
+        self.decoder.train()
+        self.linear.train()
+
+        aud_seq, pose_seq = data['music'], data['dance']
+
+        gold_seq = pose_seq[:, 1:]
+        src_aud = aud_seq[:, :-1]
+        src_pos = pose_seq[:, :-1]
+
+        optimizer.zero_grad()
+
+        output = self.forward(src_aud, src_pos, self._epoch)
+
+        loss = torch.nn.functional.mse_loss(output, gold_seq)
+
+        stats = {'loss': loss.item()}
+
+        outputs = {
+            'loss': loss,
+            'log_vars': stats,
+            'num_samples': output.size(1)
+        }
+
+        return outputs
+
+    def val_step(self, data, optimizer, **kwargs):
+        return self.test_step(data, optimizer, **kwargs)
+
+    def test_step(self, data, optimizer, **kwargs):
+        results = []
+        self.eval()
+        with torch.no_grad():
+            aud_seq_eval, _ = data['music'], data['dance']
+
+            pose_seq_out = self.generate(aud_seq_eval)
+            results.append(pose_seq_out)
+        outputs = {
+            'output_pose': results[0],
+            'file_name': data['file_names'][0]
+        }
+        return outputs
diff --git a/xrmogen/utils/logger.py b/xrmogen/utils/logger.py
index 5bea519..66fb588 100644
--- a/xrmogen/utils/logger.py
+++ b/xrmogen/utils/logger.py
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-
 from mmcv.utils import get_logger