diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..96857a4 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,29 @@ +version: 2.1 + +jobs: + python_lint: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off flake8 typing + flake8 . + + test: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off scipy pytest + pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + pytest . + +workflows: + build: + jobs: + - python_lint + - test diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..0f7ad8b --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..b3181ee --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing to DETR +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 4 spaces for indentation rather than tabs +* 80 character line length +* PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) + +## License +By contributing to DETR, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/.github/DETR.png b/.github/DETR.png new file mode 100644 index 0000000..a5a4a34 Binary files /dev/null and b/.github/DETR.png differ diff --git a/.github/ISSUE_TEMPLATE/bugs.md b/.github/ISSUE_TEMPLATE/bugs.md new file mode 100644 index 0000000..c3a1e90 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bugs.md @@ -0,0 +1,32 @@ +--- +name: "🐛 Bugs" +about: Report bugs in DETR +title: Please read & provide the following + +--- + +## Instructions To Reproduce the 🐛 Bug: + +1. what changes you made (`git diff`) or what code you wrote +``` + +``` +2. what exact command you run: +3. what you observed (including __full logs__): +``` + +``` +4. 
please simplify the steps as much as possible so they do not require additional resources to + run, such as a private dataset. + +## Expected behavior: + +If there are no obvious errors in "what you observed" provided above, +please tell us the expected behavior. + +## Environment: + +Provide your environment information using the following command: +``` +python -m torch.utils.collect_env +``` diff --git a/.github/ISSUE_TEMPLATE/questions-help-support.md b/.github/ISSUE_TEMPLATE/questions-help-support.md new file mode 100644 index 0000000..fc708dc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions-help-support.md @@ -0,0 +1,22 @@ +--- +name: "How to do something❓" +about: How to do something using DETR? + +--- + +## ❓ How to do something using DETR + +Describe what you want to do, including: +1. what inputs you will provide, if any: +2. what outputs you are expecting: + + +NOTE: + +1. Only general answers are provided. + If you want to ask about "why X did not work", please use the + [Unexpected behaviors](https://github.com/facebookresearch/detr/issues/new/choose) issue template. + +2. For questions about how to implement new models / new dataloaders / new training logic, etc., please check the documentation first. + +3. We do not answer general machine learning / computer vision questions that are not specific to DETR, such as how a model works, how to improve your training/make it converge, or what algorithm/methods can be used to achieve X. diff --git a/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md b/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md new file mode 100644 index 0000000..c392409 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md @@ -0,0 +1,41 @@ +--- +name: "Unexpected behaviors" +about: Run into unexpected behaviors when using DETR +title: Please read & provide the following + +--- + +If you do not know the root cause of the problem and wish someone to help you, please +post according to this template: + +## Instructions To Reproduce the Issue: + +1. what changes you made (`git diff`) or what code you wrote +``` + +``` +2. what exact command you run: +3. what you observed (including __full logs__): +``` + +``` +4. please simplify the steps as much as possible so they do not require additional resources to + run, such as a private dataset. + +## Expected behavior: + +If there are no obvious errors in "what you observed" provided above, +please tell us the expected behavior. + +If you expect the model to converge / work better, note that we do not give suggestions +on how to train a new model. +We will only help with it in one of the following two cases: +(1) You're unable to reproduce the results in the DETR model zoo. +(2) It indicates a DETR bug. 
+ +## Environment: + +Provide your environment information using the following command: +``` +python -m torch.utils.collect_env +``` diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..217b9be --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +.nfs* +*.ipynb +*.pyc +.dumbo.json +.DS_Store +.*.swp +*.pth +**/__pycache__/** +.ipynb_checkpoints/ +datasets/data/ +experiment-* +*.tmp +*.pkl +**/.mypy_cache/* +.mypy_cache/* +not_tracked_dir/ +.vscode +.python-version +*.sbatch +*.egg-info +build +dist +.idea diff --git a/.run/STTran_train.run.xml b/.run/STTran_train.run.xml new file mode 100644 index 0000000..cafc255 --- /dev/null +++ b/.run/STTran_train.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/STTran_train_frcnn.run.xml b/.run/STTran_train_frcnn.run.xml new file mode 100644 index 0000000..26280b3 --- /dev/null +++ b/.run/STTran_train_frcnn.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/STTran_train_frcnn_IOU.run.xml b/.run/STTran_train_frcnn_IOU.run.xml new file mode 100644 index 0000000..f69fa38 --- /dev/null +++ b/.run/STTran_train_frcnn_IOU.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/demo_CP.run.xml b/.run/demo_CP.run.xml new file mode 100644 index 0000000..bddd2d1 --- /dev/null +++ b/.run/demo_CP.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/demo_trackformer.run.xml b/.run/demo_trackformer.run.xml new file mode 100644 index 0000000..a6af023 --- /dev/null +++ b/.run/demo_trackformer.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/generate_coco_from_VidHOI.run.xml b/.run/generate_coco_from_VidHOI.run.xml new file mode 100644 index 0000000..efd0d57 --- /dev/null +++ b/.run/generate_coco_from_VidHOI.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/generate_coco_from_actiongenome.run.xml b/.run/generate_coco_from_actiongenome.run.xml new file mode 100644 index 0000000..fd251d1 --- /dev/null +++ b/.run/generate_coco_from_actiongenome.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/generate_coco_from_crowdhuman.run.xml b/.run/generate_coco_from_crowdhuman.run.xml new file mode 100644 index 0000000..1c0bafd --- /dev/null +++ b/.run/generate_coco_from_crowdhuman.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/generate_coco_from_mot.run.xml b/.run/generate_coco_from_mot.run.xml new file mode 100644 index 0000000..11b21da --- /dev/null +++ b/.run/generate_coco_from_mot.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/track.run.xml b/.run/track.run.xml new file mode 100644 index 0000000..8008421 --- /dev/null +++ b/.run/track.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_actiongenome_detr+hoi.run.xml b/.run/train_actiongenome_detr+hoi.run.xml new file mode 100644 index 0000000..355abaf --- /dev/null +++ b/.run/train_actiongenome_detr+hoi.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/train_actiongenome_detr+tracking.run.xml b/.run/train_actiongenome_detr+tracking.run.xml new file mode 100644 index 0000000..2357da8 --- /dev/null +++ b/.run/train_actiongenome_detr+tracking.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/train_actiongenome_detr.run.xml b/.run/train_actiongenome_detr.run.xml new file mode 100644 index 0000000..6d81578 --- /dev/null +++ 
b/.run/train_actiongenome_detr.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_actiongenome_ngpus.run.xml b/.run/train_actiongenome_ngpus.run.xml new file mode 100644 index 0000000..d740f2b --- /dev/null +++ b/.run/train_actiongenome_ngpus.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_consistent_pairing.run.xml b/.run/train_consistent_pairing.run.xml new file mode 100644 index 0000000..4423bbf --- /dev/null +++ b/.run/train_consistent_pairing.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_consistent_pairing_crowdhuman.run.xml b/.run/train_consistent_pairing_crowdhuman.run.xml new file mode 100644 index 0000000..98e2ff7 --- /dev/null +++ b/.run/train_consistent_pairing_crowdhuman.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_ngpu.run.xml b/.run/train_ngpu.run.xml new file mode 100644 index 0000000..f87ad65 --- /dev/null +++ b/.run/train_ngpu.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/train_vidhoi_detr+hoi.run.xml b/.run/train_vidhoi_detr+hoi.run.xml new file mode 100644 index 0000000..eb0b3ae --- /dev/null +++ b/.run/train_vidhoi_detr+hoi.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_vidhoi_detr+tracking+hoi(vsgg).run.xml b/.run/train_vidhoi_detr+tracking+hoi(vsgg).run.xml new file mode 100644 index 0000000..bd74b22 --- /dev/null +++ b/.run/train_vidhoi_detr+tracking+hoi(vsgg).run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/.run/train_vidhoi_detr+tracking.run.xml b/.run/train_vidhoi_detr+tracking.run.xml new file mode 100644 index 0000000..c3b1eff --- /dev/null +++ b/.run/train_vidhoi_detr+tracking.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_vidhoi_detr.run.xml b/.run/train_vidhoi_detr.run.xml new file mode 100644 index 0000000..0b11534 --- /dev/null +++ b/.run/train_vidhoi_detr.run.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/.run/train_vidhoi_ngpu.run.xml b/.run/train_vidhoi_ngpu.run.xml new file mode 100644 index 0000000..2c83169 --- /dev/null +++ b/.run/train_vidhoi_ngpu.run.xml @@ -0,0 +1,24 @@ + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..59bf696 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2023] [authors of TPT] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..d97dba5 --- /dev/null +++ b/README.md @@ -0,0 +1,61 @@ +# End-to-End Video Scene Graph Generation with Temporal Propagation Transformer + +This repository provides the official implementation of the [End-to-End Video Scene Graph Generation with Temporal Propagation Transformer](https://ieeexplore.ieee.org/abstract/document/10145598) paper. 
+ +## Installation + +1. `pip3 install -r requirements.txt` +2. Install PyTorch>=1.5 and torchvision>=0.6 from [here](https://pytorch.org/get-started/previous-versions/#v150). +3. `pip install pycocotools` +4. Install the MultiScaleDeformableAttention package: `python src/trackformer/models/ops/setup.py build --build-base=src/trackformer/models/ops/ install` + +## Data preparation + +### ActionGenome + +1. Preprocess and dump frames following https://github.com/JingweiJ/ActionGenome + +2. Convert to COCO annotation format using `python src/generate_coco_from_actiongenome.py` + +### VidHOI + +1. Download and prepare VidHOI following https://github.com/coldmanck/VidHOI + +2. Dump frames +``` +python src/generate_coco_from_vidhoi.py --task dump_frames +``` + +3. Convert to COCO annotation format +``` +python src/generate_coco_from_vidhoi.py --task convert_coco_annotations +``` + +## Training & Evaluation + +All the run scripts are in the `./runs` directory. + +1. Train a Transformer-based detector for object detection in individual video frames. +``` +sh ./runs/vidhoi_detr.sh your_output_dir +``` + +2. Fine-tune the Transformer-based detector together with the QPM module to further build temporal associations of detected instances. +``` +sh ./runs/vidhoi_detr+tracking.sh your_output_dir +``` + +3. Freeze all parameters of the architecture learned in the previous step, and only optimize the modules for relation recognition. +``` +sh ./runs/vidhoi_detr+tracking+hoi.sh your_output_dir +``` + +4. Jointly fine-tune the whole framework. +``` +# Set freeze_detr=False in ./runs/vidhoi_detr+hoi.sh +sh ./runs/vidhoi_detr+hoi.sh your_output_dir +``` + +## Acknowledgement + +The codebase builds upon [DETR](https://github.com/facebookresearch/detr), [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR), [TrackFormer](https://github.com/timmeinhardt/trackformer), [STTran](https://github.com/yrcong/STTran) and [ByteTrack](https://github.com/ifzhang/ByteTrack). Thanks for their wonderful work. diff --git a/cfgs/consistent_pairing.yaml b/cfgs/consistent_pairing.yaml new file mode 100644 index 0000000..0bb57e2 --- /dev/null +++ b/cfgs/consistent_pairing.yaml @@ -0,0 +1,5 @@ +track_query_propagation_strategy: consistent_pairing +clip_length: 2 +token_propagation_sample_rate: 0 +track_query_false_negative_prob: 0 +output_dir: models/consistent_pairing diff --git a/cfgs/submit.yaml b/cfgs/submit.yaml new file mode 100644 index 0000000..b6c60e7 --- /dev/null +++ b/cfgs/submit.yaml @@ -0,0 +1,20 @@ +# Number of gpus to request on each node +num_gpus: 1 +vram: 12GB +# memory allocated per GPU in GB +mem_per_gpu: 20 +# Number of nodes to request +nodes: 1 +# Duration of the job +timeout: 4320 +# Job dir. Leave empty for automatic. +job_dir: '' +# Use to run jobs locally. ('debug', 'local', 'slurm') +cluster: debug +# Partition. Leave empty for automatic. +slurm_partition: '' +# Constraint. Leave empty for automatic. 
+slurm_constraint: '' +slurm_comment: '' +slurm_gres: '' +slurm_exclude: '' \ No newline at end of file diff --git a/cfgs/track.yaml b/cfgs/track.yaml new file mode 100644 index 0000000..235e098 --- /dev/null +++ b/cfgs/track.yaml @@ -0,0 +1,46 @@ +output_dir: null +verbose: false +seed: 666 + +obj_detect_checkpoint_file: models/pretrained/mot17_train_deformable_private/checkpoint.pth + +interpolate: False +# if available load tracking results and only evaluate +load_results_dir: null + +# dataset (look into src/datasets/tracking/factory.py) +dataset_name: MOT17-ALL-ALL +data_root_dir: data +track_query_propagation_strategy: 'trackformer' + +# [False, 'debug', 'pretty'] +# compile video with: `ffmpeg -f image2 -framerate 15 -i %06d.jpg -vcodec libx264 -y movie.mp4 -vf scale=320:-1` +write_images: False +# Maps are only visualized if write_images is True +generate_attention_maps: False + +# track, evaluate and write images only for a range of frames (in float fraction) +frame_range: + start: 0.0 + end: 1.0 + +tracker_cfg: + # [False, 'center_distance', 'min_iou_0_5'] + public_detections: False + # score threshold for detections + detection_obj_score_thresh: 0.9 + # score threshold for keeping the track alive + track_obj_score_thresh: 0.8 + # NMS threshold for detection + detection_nms_thresh: 0.9 + # NMS threshold while tracking + track_nms_thresh: 0.9 + # motion model settings + # How many timesteps inactive tracks are kept and considered for reid + inactive_patience: 0 + # How similar do image and old track need to be to be considered the same person + reid_sim_threshold: 0.0 + reid_sim_only: false + reid_score_thresh: 0.8 + reid_greedy_matching: false + consistent_pairing_detection_thresh: 0.5 diff --git a/cfgs/track_reid.yaml b/cfgs/track_reid.yaml new file mode 100644 index 0000000..0a08a4f --- /dev/null +++ b/cfgs/track_reid.yaml @@ -0,0 +1,2 @@ +tracker_cfg: + inactive_patience: 5 diff --git a/cfgs/train.yaml b/cfgs/train.yaml new file mode 100644 index 0000000..345350c --- /dev/null +++ b/cfgs/train.yaml @@ -0,0 +1,165 @@ +object_detector: 'detr' +num_classes: 1 +lr: 0.0001 +lr_backbone_names: ['backbone.0'] +lr_backbone: 0.00001 +lr_linear_proj_names: ['reference_points', 'sampling_offsets'] +lr_linear_proj_mult: 0.1 +lr_track: 0.0001 +batch_size: 1 +weight_decay: 0.0001 +epochs: 50 +lr_drop: 40 +# gradient clipping max norm +clip_max_norm: 0.1 +# Deformable DETR +deformable: false +with_box_refine: false +two_stage: false +# Model parameters +freeze_detr: false +load_mask_head_from_model: null +# Backbone +# Name of the convolutional backbone to use. ('resnet50', 'resnet101') +backbone: resnet101 +# If true, we replace stride with dilation in the last convolutional block (DC5) +dilation: false +# Type of positional embedding to use on top of the image features. 
('sine', 'learned') +position_embedding: sine +# Number of feature levels the encoder processes from the backbone +num_feature_levels: 1 +# Transformer +# Number of encoding layers in the transformer +enc_layers: 6 +# Number of decoding layers in the transformer +dec_layers: 6 +# Intermediate size of the feedforward layers in the transformer blocks +dim_feedforward: 2048 +# Size of the embeddings (dimension of the transformer) +hidden_dim: 256 +# Dropout applied in the transformer +dropout: 0.1 +# Number of attention heads inside the transformer's attentions +nheads: 8 +# Number of object queries +num_queries: 100 +pre_norm: false +dec_n_points: 4 +enc_n_points: 4 + +# Tracking +tracking: false +# In addition to detection also run tracking evaluation with default configuration from `cfgs/track.yaml` +tracking_eval: true +# Range of possible random previous frames +track_prev_frame_range: 0 +track_prev_frame_rnd_augs: 0.01 +track_query_false_positive_prob: 0.1 +track_query_false_negative_prob: 0.4 +track_query_false_positive_eos_weight: true +track_query_noise: 0.0 +track_attention: false + +track_query_propagation_strategy: 'trackformer' +tracking_token_propagation: true +tracking_match_propagation: true +tracking_match_propagation_skip_frame: false +token_propagation_sample_rate: 0.1 + +# Segmentation +masks: false +# Matcher +# Class coefficient in the matching cost +set_cost_class: 1.0 +# L1 box coefficient in the matching cost +set_cost_bbox: 5.0 +# giou box coefficient in the matching cost +set_cost_giou: 2.0 +# Loss +# Disables auxiliary decoding losses (loss at each layer) +aux_loss: true +mask_loss_coef: 1.0 +dice_loss_coef: 1.0 +cls_loss_coef: 1.0 +bbox_loss_coef: 5.0 +giou_loss_coef: 2 +# Relative classification weight of the no-object class +eos_coef: 0.1 +focal_loss: false +focal_alpha: 0.25 +# Dataset +dataset: coco +train_split: train +val_split: val +coco_path: data/coco_2017 +coco_panoptic_path: null +mot_path: data/MOT17 +crowdhuman_path: data/CrowdHuman +actiongenome_path: data/ActionGenome +vidhoi_path: data/VidHOI +clip_length: None # load as video clip +# allows for joint training of mot and crowdhuman with +# the `mot_crowdhuman` dataset +crowdhuman_train_split: null +coco_and_crowdhuman_prev_frame_rnd_augs: 0.05 +img_transform: + max_size: 1333 + val_width: 800 +# Miscellaneous +# path where to save, empty for no saving +output_dir: models/debug +# device to use for training / testing +device: cuda +seed: 42 +# resume from checkpoint +resume: '' +resume_shift_neuron: False +# resume optimization from checkpoint +resume_optim: false +# resume Visdom visualization +resume_vis: false +start_epoch: 1 +eval_only: false +eval_train: false +num_workers: 2 +val_interval: 1 +debug: false +# epoch interval for model saving. if 0 only save last and best models +save_model_interval: 0 +# distributed training parameters +# number of distributed processes +world_size: 1 +# url used to set up distributed training +dist_url: env:// +# visualization parameters +# Visdom port. +vis_port: 8090 +# Visdom server URL. 
+vis_server: http://localhost +no_vis: true +vis_and_log_interval: 50 + +# HOI detection +hoi_detection: false +num_hoi_queries: 16 +num_relations: 26 +hoi_dec_layers: 4 +hoi_aux_loss: true +hoi_hard_mining: false +video_sgg_train: false +video_sgg_eval: false +hoi_use_interaction_decoder: true +hoi_use_temporal_dynamics: false +hoi_use_temporal_dynamics_prev_length: 2 +hoi_oracle_mode: false +hoi_oracle_mode_only_given_bbox: false # SGCls mode for SGG +hoi_oracle_mode_use_instant_trajectory: false +hoi_oracle_mode_use_roialign_union_feat: false +hoi_inference_apply_nms: true +hoi_instance_fuse_spatial_and_semantic_feat: true +hoi_relation_propagation_on_inference: false + +# STTran +sgg_use_STTran: false +sgg_mode: 'sgdet' +sgg_postprocessing_tracker: null diff --git a/cfgs/train_actiongenome.yaml b/cfgs/train_actiongenome.yaml new file mode 100644 index 0000000..2ec9094 --- /dev/null +++ b/cfgs/train_actiongenome.yaml @@ -0,0 +1,12 @@ +dataset: actiongenome +train_split: train +val_split: test +num_classes: 36 +#train_split: train_v1000 +#val_split: test_v200 + +# train from COCO pre-trained +resume: models/pretrained/r50_deformable_detr-checkpoint.pth +lr: 0.00001 +lr_backbone: 0.000001 +lr_drop: 20 diff --git a/cfgs/train_coco_person_masks.yaml b/cfgs/train_coco_person_masks.yaml new file mode 100644 index 0000000..9d99794 --- /dev/null +++ b/cfgs/train_coco_person_masks.yaml @@ -0,0 +1,10 @@ +dataset: coco_person + +resume: models/mot17_train_pretrain_CH_deformable/checkpoint.pth +load_mask_head_from_model: models/detr-r50-panoptic-00ce5173.pth +freeze_detr: true +masks: true + +lr: 0.0001 +lr_drop: 50 +epochs: 50 \ No newline at end of file diff --git a/cfgs/train_crowdhuman.yaml b/cfgs/train_crowdhuman.yaml new file mode 100644 index 0000000..8dfb40b --- /dev/null +++ b/cfgs/train_crowdhuman.yaml @@ -0,0 +1,12 @@ +dataset: mot_crowdhuman +crowdhuman_train_split: train_val +train_split: null +val_split: mot17_train_coco + +# train from COCO pre-trained +resume: models/pretrained/r50_deformable_detr-checkpoint.pth +lr: 0.0001 +lr_backbone: 0.00001 + +epochs: 50 +lr_drop: 20 \ No newline at end of file diff --git a/cfgs/train_deformable.yaml b/cfgs/train_deformable.yaml new file mode 100644 index 0000000..995420f --- /dev/null +++ b/cfgs/train_deformable.yaml @@ -0,0 +1,5 @@ +deformable: true +backbone: resnet50 +num_feature_levels: 4 +num_queries: 300 +dim_feedforward: 1024 \ No newline at end of file diff --git a/cfgs/train_focal_loss.yaml b/cfgs/train_focal_loss.yaml new file mode 100644 index 0000000..3d58bec --- /dev/null +++ b/cfgs/train_focal_loss.yaml @@ -0,0 +1,4 @@ +focal_loss: true +focal_alpha: 0.25 +cls_loss_coef: 2.0 +set_cost_class: 2.0 \ No newline at end of file diff --git a/cfgs/train_frcnn.yaml b/cfgs/train_frcnn.yaml new file mode 100644 index 0000000..48f8c03 --- /dev/null +++ b/cfgs/train_frcnn.yaml @@ -0,0 +1,8 @@ +object_detector: 'frcnn' +batch_size: 2 + +lr: 0.001 +weight_decay: 0.0005 + +epochs: 10 +lr_drop: 7 diff --git a/cfgs/train_full_res.yaml b/cfgs/train_full_res.yaml new file mode 100644 index 0000000..cf90ac4 --- /dev/null +++ b/cfgs/train_full_res.yaml @@ -0,0 +1,3 @@ +img_transform: + max_size: 1920 + val_width: 1080 \ No newline at end of file diff --git a/cfgs/train_hoi.yaml b/cfgs/train_hoi.yaml new file mode 100644 index 0000000..d29db5e --- /dev/null +++ b/cfgs/train_hoi.yaml @@ -0,0 +1,12 @@ +resume: xxx # fine-tuned DETR on customized dataset +batch_size: 1 + +tracking_token_propagation: False +hoi_detection: True +freeze_detr: True + +lr: 
0.00001 +lr_backbone: 0.000001 +lr_drop: 4 +epochs: 7 +save_model_interval: 1 diff --git a/cfgs/train_mot17.yaml b/cfgs/train_mot17.yaml new file mode 100644 index 0000000..61a3306 --- /dev/null +++ b/cfgs/train_mot17.yaml @@ -0,0 +1,10 @@ +dataset: mot +train_split: mot17_train_coco +val_split: mot17_train_coco + +# fine-tune +resume: xxx +lr: 0.00001 +lr_backbone: 0.000001 +epochs: 30 +lr_drop: 20 diff --git a/cfgs/train_mot17_cross_val.yaml b/cfgs/train_mot17_cross_val.yaml new file mode 100644 index 0000000..b5a1335 --- /dev/null +++ b/cfgs/train_mot17_cross_val.yaml @@ -0,0 +1,11 @@ +dataset: mot +train_split: mot17_train_cross_val_frame_0_0_to_0_5_coco +val_split: mot17_train_cross_val_frame_0_5_to_1_0_coco + +# train from COCO pre-trained (for ablations) +resume: models/pretrained/r50_deformable_detr-checkpoint.pth +lr: 0.0001 +lr_backbone: 0.00001 + +epochs: 30 +lr_drop: 20 diff --git a/cfgs/train_mots20.yaml b/cfgs/train_mots20.yaml new file mode 100644 index 0000000..cd934d6 --- /dev/null +++ b/cfgs/train_mots20.yaml @@ -0,0 +1,12 @@ +dataset: mot +mot_path: data/MOTS20 +train_split: mots20_train_coco +val_split: mots20_train_coco + +resume: models/mot17_train_pretrain_CH_deformable_with_coco_person_masks/checkpoint.pth +masks: true +lr: 0.00001 +lr_backbone: 0.000001 + +epochs: 40 +lr_drop: 40 \ No newline at end of file diff --git a/cfgs/train_sttran.yaml b/cfgs/train_sttran.yaml new file mode 100644 index 0000000..a9aa79f --- /dev/null +++ b/cfgs/train_sttran.yaml @@ -0,0 +1,10 @@ +object_detector: 'frcnn' +sgg_use_STTran: True + +clip_length: 3 +batch_size: 1 +lr: 0.0001 +weight_decay: 0.0005 + +epochs: 10 +lr_drop: 6 diff --git a/cfgs/train_tracking.yaml b/cfgs/train_tracking.yaml new file mode 100644 index 0000000..7228a5e --- /dev/null +++ b/cfgs/train_tracking.yaml @@ -0,0 +1,5 @@ +tracking: true +tracking_eval: true +track_prev_frame_range: 5 +track_query_false_positive_eos_weight: true +track_query_propagation_strategy: 'trackformer' diff --git a/cfgs/train_vidhoi.yaml b/cfgs/train_vidhoi.yaml new file mode 100644 index 0000000..c42e4f5 --- /dev/null +++ b/cfgs/train_vidhoi.yaml @@ -0,0 +1,10 @@ +dataset: vidhoi +num_classes: 78 +num_relations: 50 +train_split: train +val_split: validation +#train_split: train_v30 +#val_split: validation_v10 + +# train from COCO pre-trained +resume: models/pretrained/r50_deformable_detr-checkpoint.pth diff --git a/cfgs/train_vsgg.yaml b/cfgs/train_vsgg.yaml new file mode 100644 index 0000000..df37b18 --- /dev/null +++ b/cfgs/train_vsgg.yaml @@ -0,0 +1,20 @@ +resume: xxx # fine-tuned DETR or pretrained tracker on customized dataset + +tracking_eval: false +video_sgg_eval: true +video_sgg_train: true +track_query_propagation_strategy: consistent_pairing +#track_query_false_positive_eos_weight: false +token_propagation_sample_rate: 0 + +clip_length: 3 +lr: 0.00001 +lr_backbone: 0.000001 +lr_drop: 4 +epochs: 7 +save_model_interval: 1 + +## for stage2 training +#hoi_detection: True +#freeze_detr: True +#hoi_use_temporal_dynamics: True diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..bfcda32 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,3 @@ +* +!.gitignore +!snakeboard diff --git a/docs/INSTALL.md b/docs/INSTALL.md new file mode 100644 index 0000000..1c6a933 --- /dev/null +++ b/docs/INSTALL.md @@ -0,0 +1,62 @@ +# Installation + +1. Clone and enter this repository: + ``` + git clone git@github.com:timmeinhardt/trackformer.git + cd trackformer + ``` + +2. Install packages for Python 3.7: + + 1. 
`pip3 install -r requirements.txt` + 2. Install PyTorch 1.5 and torchvision 0.6 from [here](https://pytorch.org/get-started/previous-versions/#v150). + 3. Fix: `pip install pycocotools`. (The original instruction, "Install pycocotools (with fixed ignore flag): `pip3 install -U 'git+https://github.com/timmeinhardt/cocoapi.git#subdirectory=PythonAPI'`", gives different eval results.) + 4. Install the MultiScaleDeformableAttention package: `python src/trackformer/models/ops/setup.py build --build-base=src/trackformer/models/ops/ install` + +3. Download and unpack datasets in the `data` directory: + + 1. [MOT17](https://motchallenge.net/data/MOT17/): + ``` + wget https://motchallenge.net/data/MOT17.zip + unzip MOT17.zip + python src/generate_coco_from_mot.py + ``` + 2. (Optional) [MOTS20](https://motchallenge.net/data/MOTS/): + ``` + wget https://motchallenge.net/data/MOTS.zip + unzip MOTS.zip + python src/generate_coco_from_mot.py --mots + ``` + 3. (Optional) [CrowdHuman](https://www.crowdhuman.org/download.html): + + 1. Create a `CrowdHuman` and `CrowdHuman/annotations` directory. + 2. Download and extract the `train` and `val` datasets including their corresponding `*.odgt` annotation file into the `CrowdHuman` directory. + 3. Create a `CrowdHuman/train_val` directory and merge or symlink the `train` and `val` image folders. + 4. Run `python src/generate_coco_from_crowdhuman.py` + 5. The final folder structure should resemble this: + ~~~ + |-- data + |-- CrowdHuman + | |-- train + | | |-- *.jpg + | |-- val + | | |-- *.jpg + | |-- train_val + | | |-- *.jpg + | |-- annotations + | | |-- annotation_train.odgt + | | |-- annotation_val.odgt + | | |-- train_val.json + ~~~ + +4. Download and unpack pretrained TrackFormer model files in the `models` directory: + ``` + wget https://vision.in.tum.de/webshare/u/meinhard/trackformer_models.zip + unzip trackformer_models.zip + ``` + +5. (Optional) The evaluation of MOTS20 metrics requires two steps: + 1. Run TrackFormer with `src/track.py` and output prediction files + 2. Download the official MOTChallenge [devkit](https://github.com/dendorferpatrick/MOTChallengeEvalKit) and run the MOTS evaluation on the prediction files + +In order to configure, log and reproduce our computational experiments, we structure our code with the [Sacred](http://sacred.readthedocs.io/en/latest/index.html) framework. For a detailed explanation of the Sacred interface please read its documentation. diff --git a/docs/MOT17-03-SDP.gif b/docs/MOT17-03-SDP.gif new file mode 100644 index 0000000..252167a Binary files /dev/null and b/docs/MOT17-03-SDP.gif differ diff --git a/docs/MOTS20-07.gif b/docs/MOTS20-07.gif new file mode 100644 index 0000000..3885c15 Binary files /dev/null and b/docs/MOTS20-07.gif differ diff --git a/docs/TRAIN.md b/docs/TRAIN.md new file mode 100644 index 0000000..f3e0b37 --- /dev/null +++ b/docs/TRAIN.md @@ -0,0 +1,118 @@ +# Train TrackFormer + +We provide the code as well as intermediate models of our entire training pipeline for multiple datasets. Monitoring of the training/evaluation progress is possible via command line as well as [Visdom](https://github.com/fossasia/visdom.git). For the latter, a Visdom server must be running at `vis_port=8090` and `vis_server=http://localhost` (see `cfgs/train.yaml`). To deactivate Visdom logging run a training with the `no_vis=True` flag. + +
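+As a convenience, a matching Visdom server can be started before training; a minimal sketch, assuming the `visdom` package from `requirements.txt` is installed (the port mirrors the `vis_port` default above):
+```
+python -m visdom.server -port 8090
+```
+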
+ [Figure: Snakeboard demo] +
+ +The settings for each dataset are specified in the respective configuration files, e.g., `cfgs/train_crowdhuman.yaml`. + +## CrowdHuman pre-training + +``` +python src/train.py with \ + deformable \ + tracking \ + crowdhuman \ + full_res \ + output_dir=models/crowdhuman_train_val_deformable_v2 \ +``` + +## MOT17 + +#### Private detections + +``` +python src/train.py with \ + deformable \ + tracking \ + mot17 \ + full_res \ + resume=models/crowdhuman_train_val_deformable/checkpoint.pth \ + output_dir=models/mot17_train_deformable_private_v2 \ +``` + +#### Public detections + +``` +python src/train.py with \ + deformable \ + tracking \ + mot17 \ + full_res \ + resume=models/r50_deformable_detr-checkpoint.pth \ + output_dir=models/mot17_train_deformable_public_v2 \ + epochs=40 \ + lr_drop=10 +``` + +## MOTS20 + +For our MOTS20 test set submission, we finetune a MOT17 private detection model without deformable attention, i.e., vanilla DETR, which was pre-trained on the CrowdHuman dataset. The finetuning itself consists of two training steps: (i) the original DETR panoptic segmentation head on the COCO person segmentation data and (ii) the entire TrackFormer model (including segmentation head) on the MOTS20 training set. + +``` +python src/train.py with \ + tracking \ + coco_person_masks \ + output_dir=models/mot17_train_private_coco_person_masks_v2 \ +``` + +``` +python src/train.py with \ + tracking \ + mots20 \ + output_dir=models/mots20_train_masks_v2 \ +``` + +### Ablation studies + +Will be added after acceptance of the paper. + +## Custom Dataset + +TrackFormer can be trained on additional/new object detection or multi-object tracking datasets without changing our codebase. The `crowdhuman` or `mot` datasets merely require a [COCO style](https://www.immersivelimit.com/tutorials/create-coco-annotations-from-scratch) annotation file and the following folder structure: + +~~~ +|-- data + |-- custom_dataset + | |-- train + | | |-- *.jpg + | |-- val + | | |-- *.jpg + | |-- annotations + | | |-- train.json + | | |-- val.json +~~~ + +In the case of a multi-object tracking dataset, the original COCO annotation style must be extended with `seq_length`, `first_frame_image_id` and `track_id` fields. See the `src/generate_coco_from_mot.py` script for details. For example, the following command finetunes our `MOT17` private model for an additional 20 epochs on a custom dataset: + +``` +python src/train.py with \ + deformable \ + tracking \ + mot17 \ + full_res \ + resume=models/mot17_train_deformable_private/checkpoint.pth \ + output_dir=models/custom_dataset_train_deformable \ + mot_path=data/custom_dataset \ + train_split=train \ + val_split=val \ + epochs=20 \ +``` + +## Run with Submitit + +Furthermore, we provide a script for starting Slurm jobs with [submitit](https://github.com/facebookincubator/submitit). This includes a convenient command line interface for Slurm options as well as preemption and resuming capabilities. 
The aforementioned CrowdHuman pre-training can be executed on 8 x 16 GB GPUs with the following command: + +``` +python src/run_with_submitit.py with \ + num_gpus=8 \ + vram=16GB \ + cluster=slurm \ + train.deformable \ + train.tracking \ + train.crowdhuman \ + train.full_res \ + train.output_dir=models/crowdhuman_train_val_deformable_v2 \ +``` \ No newline at end of file diff --git a/docs/method.png b/docs/method.png new file mode 100644 index 0000000..ff967b4 Binary files /dev/null and b/docs/method.png differ diff --git a/docs/snakeboard.gif b/docs/snakeboard.gif new file mode 100644 index 0000000..6defbd9 Binary files /dev/null and b/docs/snakeboard.gif differ diff --git a/docs/trackformer_README.md b/docs/trackformer_README.md new file mode 100644 index 0000000..2e834c2 --- /dev/null +++ b/docs/trackformer_README.md @@ -0,0 +1,120 @@ +# TrackFormer: Multi-Object Tracking with Transformers + +This repository provides the official implementation of the [TrackFormer: Multi-Object Tracking with Transformers](https://arxiv.org/abs/2101.02702) paper by [Tim Meinhardt](https://dvl.in.tum.de/team/meinhardt/), [Alexander Kirillov](https://alexander-kirillov.github.io/), [Laura Leal-Taixe](https://dvl.in.tum.de/team/lealtaixe/) and [Christoph Feichtenhofer](https://feichtenhofer.github.io/). The codebase builds upon [DETR](https://github.com/facebookresearch/detr), [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR) and [Tracktor](https://github.com/phil-bergmann/tracking_wo_bnw). + + + +
+ [Figures: MOT17-03-SDP and MOTS20-07 tracking demos] +
+ +## Abstract + +The challenging task of multi-object tracking (MOT) requires simultaneous reasoning about track initialization, identity, and spatiotemporal trajectories. +We formulate this task as a frame-to-frame set prediction problem and introduce TrackFormer, an end-to-end MOT approach based on an encoder-decoder Transformer architecture. +Our model achieves data association between frames via attention by evolving a set of track predictions through a video sequence. +The Transformer decoder initializes new tracks from static object queries and autoregressively follows existing tracks in space and time with the new concept of identity preserving track queries. +Both decoder query types benefit from self- and encoder-decoder attention on global frame-level features, thereby omitting any additional graph optimization and matching or modeling of motion and appearance. +TrackFormer represents a new tracking-by-attention paradigm and yields state-of-the-art performance on the task of multi-object tracking (MOT17) and segmentation (MOTS20). + +
+ TrackFormer casts multi-object tracking as a set prediction problem performing joint detection and tracking-by-attention. The architecture consists of a CNN for image feature extraction, a Transformer encoder for image feature encoding and a Transformer decoder which applies self- and encoder-decoder attention to produce output embeddings with bounding box and class information. +
+ +## Installation + +We refer to our [docs/INSTALL.md](INSTALL.md) for detailed installation instructions. + +## Train TrackFormer + +We refer to our [docs/TRAIN.md](TRAIN.md) for detailed training instructions. + +## Evaluate TrackFormer + +In order to evaluate TrackFormer on a multi-object tracking dataset, we provide the `src/track.py` script which supports several datasets and splits interchangeably via the `dataset_name` argument (see `src/datasets/tracking/factory.py` for an overview of all datasets). The default tracking configuration is specified in `cfgs/track.yaml`. To facilitate the reproducibility of our results, we provide evaluation metrics for both the train and test set. + +### MOT17 + +#### Private detections + +``` +python src/track.py with reid +``` +
+
+| MOT17 | MOTA | IDF1 | MT | ML | FP | FN | ID SW. |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| **Train** | 68.1 | 67.6 | 816 | 207 | 33549 | 71937 | 1935 |
+| **Test** | 65.0 | 63.9 | 1074 | 324 | 70443 | 123552 | 3528 |
+
+
+ +#### Public detections (DPM, FRCNN, SDP) + +``` +python src/track.py with \ + reid \ + public_detections=min_iou_0_5 \ + obj_detect_checkpoint_file=models/mots20_train_masks/checkpoint.pth +``` +
+
+| MOT17 | MOTA | IDF1 | MT | ML | FP | FN | ID SW. |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| **Train** | 67.2 | 66.9 | 663 | 294 | 14640 | 94122 | 1866 |
+| **Test** | 62.5 | 60.7 | 702 | 632 | 32828 | 174921 | 3917 |
+
+
+ +### MOTS20 + +``` +python src/track.py with \ + dataset_name=MOTS20-ALL \ + obj_detect_checkpoint_file=models/mots20_train_masks/checkpoint.pth +``` + +Our tracking script only applies MOT17 metrics evaluation but outputs MOTS20 mask prediction files. To evaluate these download the official [MOTChallengeEvalKit](https://github.com/dendorferpatrick/MOTChallengeEvalKit). + +
+
+| MOTS20 | sMOTSA | IDF1 | FP | FN | IDs |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| **Train** | -- | -- | -- | -- | -- |
+| **Test** | 54.9 | 63.6 | 2233 | 7195 | 278 |
+
+ +### Demo + +To facilitate the application of TrackFormer, we provide a demo interface which allows for a quick processing of a given video sequence. + +``` +ffmpeg -i data/snakeboard/snakeboard.mp4 -vf fps=30 data/snakeboard/%06d.png + +python src/track.py with \ + dataset_name=DEMO \ + data_root_dir=data/snakeboard \ + output_dir=data/snakeboard \ + write_images=pretty +``` + +
+ [Figure: Snakeboard demo] +
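+The images written by the tracker can afterwards be compiled into a video; the comment in `cfgs/track.yaml` suggests the following ffmpeg call (a sketch, assuming it is run from the directory containing the written `%06d.jpg` frames):
+```
+ffmpeg -f image2 -framerate 15 -i %06d.jpg -vcodec libx264 -y movie.mp4 -vf scale=320:-1
+```
+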
+ +## Publication +If you use this software in your research, please cite our publication: + +``` +@InProceedings{meinhardt2021trackformer, + title={TrackFormer: Multi-Object Tracking with Transformers}, + author={Tim Meinhardt and Alexander Kirillov and Laura Leal-Taixe and Christoph Feichtenhofer}, + year={2021}, + eprint={2101.02702}, + archivePrefix={arXiv}, +} +``` diff --git a/docs/visdom.gif b/docs/visdom.gif new file mode 100644 index 0000000..d56597c Binary files /dev/null and b/docs/visdom.gif differ diff --git a/logs/.gitignore b/logs/.gitignore new file mode 100644 index 0000000..cc14549 --- /dev/null +++ b/logs/.gitignore @@ -0,0 +1,3 @@ +* +!visdom +!.gitignore diff --git a/logs/visdom/.gitignore b/logs/visdom/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/logs/visdom/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9532e1b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,123 @@ +argon2-cffi==20.1.0 +astroid==2.4.2 +async-generator==1.10 +attrs==19.3.0 +backcall==0.2.0 +bleach==3.2.3 +certifi==2020.4.5.2 +cffi==1.14.4 +chardet==3.0.4 +cloudpickle==1.6.0 +colorama==0.4.3 +cycler==0.10.0 +Cython==0.29.20 +decorator==4.4.2 +defusedxml==0.6.0 +docopt==0.6.2 +entrypoints==0.3 +filelock==3.0.12 +flake8==3.8.3 +flake8-import-order==0.18.1 +future==0.18.2 +gdown==3.12.2 +gitdb==4.0.5 +GitPython==3.1.3 +idna==2.9 +imageio==2.8.0 +importlib-metadata==1.6.1 +ipykernel==5.4.3 +ipython==7.19.0 +ipython-genutils==0.2.0 +ipywidgets==7.6.3 +isort==5.6.4 +jedi==0.18.0 +Jinja2==2.11.2 +jsonpatch==1.25 +jsonpickle==1.4.1 +jsonpointer==2.0 +jsonschema==3.2.0 +jupyter==1.0.0 +jupyter-client==6.1.11 +jupyter-console==6.2.0 +jupyter-core==4.7.0 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwisolver==1.2.0 +lap==0.4.0 +lapsolver==1.1.0 +lazy-object-proxy==1.4.3 +MarkupSafe==1.1.1 +matplotlib==3.2.1 +mccabe==0.6.1 +mistune==0.8.4 +more-itertools==8.4.0 +motmetrics==1.2.0 +munch==2.5.0 +nbclient==0.5.1 +nbconvert==6.0.7 +nbformat==5.1.2 +nest-asyncio==1.5.1 +networkx==2.4 +ninja==1.10.0.post2 +notebook==6.2.0 +numpy==1.18.5 +opencv-python==4.2.0.34 +packaging==20.4 +pandas==1.0.5 +pandocfilters==1.4.3 +parso==0.8.1 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==7.1.2 +pluggy==0.13.1 +prometheus-client==0.9.0 +prompt-toolkit==3.0.14 +ptyprocess==0.7.0 +py==1.8.2 +py-cpuinfo==6.0.0 +pyaml==20.4.0 +pycodestyle==2.6.0 +pycparser==2.20 +pyflakes==2.2.0 +Pygments==2.7.4 +pylint==2.6.0 +pyparsing==2.4.7 +pyrsistent==0.17.3 +PySocks==1.7.1 +pytest==5.4.3 +pytest-benchmark==3.2.3 +python-dateutil==2.8.1 +pytz==2020.1 +PyWavelets==1.1.1 +PyYAML==5.3.1 +pyzmq==19.0.1 +qtconsole==5.0.2 +QtPy==1.9.0 +requests==2.23.0 +sacred==0.8.1 +scikit-image==0.17.2 +scipy==1.4.1 +seaborn==0.10.1 +Send2Trash==1.5.0 +six==1.15.0 +smmap==3.0.4 +submitit==1.1.5 +terminado==0.9.2 +testpath==0.4.4 +tifffile==2020.6.3 +toml==0.10.2 +torchfile==0.1.0 +tornado==6.1 +tqdm==4.46.1 +traitlets==5.0.5 +typed-ast==1.4.1 +typing-extensions==3.7.4.3 +urllib3==1.25.9 +visdom==0.1.8.9 +wcwidth==0.2.5 +webencodings==0.5.1 +websocket-client==0.57.0 +widgetsnbextension==3.5.1 +wrapt==1.12.1 +xmltodict==0.12.0 +zipp==3.1.0 diff --git a/runs/actiongenome/actiongenome_detr+hoi.sh 
b/runs/actiongenome/actiongenome_detr+hoi.sh new file mode 100644 index 0000000..f7078f4 --- /dev/null +++ b/runs/actiongenome/actiongenome_detr+hoi.sh @@ -0,0 +1,17 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py \ + with deformable actiongenome hoi backbone=resnet101 \ + resume=models/actiongenome/ais_detr-resnet101_bf64ce0_run3/checkpoint_epoch20.pth \ + val_split=test_v500 \ + output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + +# # eval +# python -u src/train.py \ +# with deformable actiongenome hoi backbone=resnet101 \ +# resume=models/actiongenome/ais_r101detr+hoi_4f459af_weakerdetr/checkpoint_epoch7.pth \ +# eval_only=True >> models/actiongenome/ais_r101detr+hoi_4f459af_weakerdetr/checkpoint_epoch7_evallog.txt diff --git a/runs/actiongenome/actiongenome_detr+tracking+hoi.sh b/runs/actiongenome/actiongenome_detr+tracking+hoi.sh new file mode 100644 index 0000000..514edf8 --- /dev/null +++ b/runs/actiongenome/actiongenome_detr+tracking+hoi.sh @@ -0,0 +1,18 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py \ +# with deformable actiongenome vsgg hoi_detection=True freeze_detr=False backbone=resnet101 \ +# resume=models/actiongenome/ais_r101detr+tracking+hoi_4f459af_run2/checkpoint_epoch5.pth \ +# hoi_use_temporal_dynamics=True clip_length=3 val_split=test_v500 \ +# output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + +# eval +python -u src/train.py \ + with deformable actiongenome vsgg hoi_detection=True freeze_detr=True backbone=resnet101 \ + resume=models/actiongenome/ais_r101detr+tracking+hoi_4f459af_run2/jointly-tune/checkpoint_epoch4.pth \ + hoi_oracle_mode=False hoi_use_temporal_dynamics=True clip_length=3 eval_only=True hoi_relation_propagation_on_inference=False \ + > models/actiongenome/ais_r101detr+tracking+hoi_4f459af_run2/jointly-tune/checkpoint_epoch4_evallog.txt diff --git a/runs/actiongenome/actiongenome_detr+tracking.sh b/runs/actiongenome/actiongenome_detr+tracking.sh new file mode 100644 index 0000000..8e35cdc --- /dev/null +++ b/runs/actiongenome/actiongenome_detr+tracking.sh @@ -0,0 +1,10 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +# eval +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable actiongenome vsgg train_split=train_v1000 val_split=test_v500 resume=models/actiongenome/ais_vsgg_nohoi_lr\=1e-5_65c6820/checkpoint_epoch9.pth eval_only=True output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable actiongenome vsgg train_split=train_v1000 val_split=test_v200 resume=models/actiongenome/ais_detr-33d206b/checkpoint_epoch50.pth output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/actiongenome/actiongenome_detr.sh b/runs/actiongenome/actiongenome_detr.sh new file mode 100644 index 0000000..37a08cd --- /dev/null +++ b/runs/actiongenome/actiongenome_detr.sh @@ -0,0 +1,10 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 
--master_port=11112 --use_env src/train.py with deformable actiongenome backbone=resnet101 \ + lr=0.0001 lr_backbone=0.00001 batch_size=1 val_split=test_v500 save_model_interval=1 \ + lr_drop=20 resume=models/pretrained/r101_deformable_detr-checkpoint.pth \ + output_dir=$OUTPUT_DIR > $OUTPUT_DIR/log.txt diff --git a/runs/consistent_pairing_base.sh b/runs/consistent_pairing_base.sh new file mode 100644 index 0000000..137fb8e --- /dev/null +++ b/runs/consistent_pairing_base.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u src/train.py with deformable tracking mot17_cross_val consistent_pairing output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt \ No newline at end of file diff --git a/runs/consistent_pairing_crowdhuman.sh b/runs/consistent_pairing_crowdhuman.sh new file mode 100644 index 0000000..0547e5d --- /dev/null +++ b/runs/consistent_pairing_crowdhuman.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable tracking crowdhuman consistent_pairing output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt \ No newline at end of file diff --git a/runs/consistent_pairing_mot17.sh b/runs/consistent_pairing_mot17.sh new file mode 100644 index 0000000..91012bf --- /dev/null +++ b/runs/consistent_pairing_mot17.sh @@ -0,0 +1,8 @@ +OUTPUT_DIR=$1 +RESUME_MODEL=$2 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable tracking mot17 consistent_pairing resume=$RESUME_MODEL output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt \ No newline at end of file diff --git a/runs/consistent_pairing_ngpus.sh b/runs/consistent_pairing_ngpus.sh new file mode 100644 index 0000000..040aa7c --- /dev/null +++ b/runs/consistent_pairing_ngpus.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable tracking mot17_cross_val consistent_pairing output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt \ No newline at end of file diff --git a/runs/demo_CP.sh b/runs/demo_CP.sh new file mode 100644 index 0000000..63edb6f --- /dev/null +++ b/runs/demo_CP.sh @@ -0,0 +1,14 @@ +demo_video_name=MOT17-02-FRCNN +model_folder=$1 + +python -u src/track.py with \ + dataset_name=$demo_video_name \ + output_dir="${model_folder}/demos" \ + frame_range.start=0.5 \ + obj_detect_checkpoint_file="${model_folder}/checkpoint.pth" \ + verbose=True \ + write_images=debug + +ffmpeg -framerate 8 -start_number 301 -i "${model_folder}/demos/${demo_video_name}/${demo_video_name}/%6d.jpg" -pix_fmt yuv420p -c:v libx264 "${model_folder}/demos/out.mp4" + +echo 'Done!' 
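The demo script above assumes a trained checkpoint already exists under the model folder passed as its first argument, and it must be launched from the repository root because it calls `src/track.py` through a relative path. A minimal invocation might look like the following sketch; the model directory name here is a placeholder for illustration, not a path shipped with the repo.

```bash
# Hypothetical model directory; any folder containing a checkpoint.pth works.
bash runs/demo_CP.sh models/mot17/my_consistent_pairing_run

# track.py writes per-frame visualizations under <model_folder>/demos/,
# which ffmpeg then encodes into <model_folder>/demos/out.mp4.
```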
diff --git a/runs/detr_crowdhuman.sh b/runs/detr_crowdhuman.sh new file mode 100644 index 0000000..1ba6d66 --- /dev/null +++ b/runs/detr_crowdhuman.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable crowdhuman output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/tmp.sh b/runs/tmp.sh new file mode 100644 index 0000000..a033ed9 --- /dev/null +++ b/runs/tmp.sh @@ -0,0 +1,22 @@ +cd /218019030/projects/VideoSG-on-trackformer/ + +# # eval +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env \ +# src/train.py with deformable vidhoi vsgg \ +# resume='models/vidhoi/ais_ORACLE_MODE_detr+tracking+hoi_59c6a9c+union_feat+hoi_use_temporal_dynamics/checkpoint_epoch5.pth' \ +# hoi_detection=True freeze_detr=True clip_length=3 \ +# hoi_oracle_mode=True hoi_use_temporal_dynamics=True hoi_oracle_mode_use_roialign_union_feat=True \ +# eval_only=True hoi_relation_propagation_on_inference=True > 'models/vidhoi/ais_ORACLE_MODE_detr+tracking+hoi_59c6a9c+union_feat+hoi_use_temporal_dynamics/relation_prop_eval_epoch5.log.txt' + + +# coco detr_resnet101 train +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable \ + backbone=resnet101 batch_size=1 resume='models/coco/ais_detr-resnet101_59c6a9d/checkpoint_epoch10.pth' \ + output_dir='models/coco/ais_detr-resnet101_59c6a9d' lr_drop=20 >> 'models/coco/ais_detr-resnet101_59c6a9d/log.txt' + + +# # python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable actiongenome backbone=resnet101 \ +# # resume=models/actiongenome/ais_detr-resnet101_59c6a9d/checkpoint.pth eval_only=True > models/actiongenome/ais_detr-resnet101_59c6a9d/eval_all_log.txt + +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable actiongenome \ +# resume=models/actiongenome/ais_detr_59c6a9c/checkpoint.pth eval_only=True > models/actiongenome/ais_detr_59c6a9c/eval_all_log.txt diff --git a/runs/trackformer_base.sh b/runs/trackformer_base.sh new file mode 100644 index 0000000..4229823 --- /dev/null +++ b/runs/trackformer_base.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u src/train.py with deformable tracking mot17_cross_val clip_length=3 output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt \ No newline at end of file diff --git a/runs/trackformer_crowdhuman.sh b/runs/trackformer_crowdhuman.sh new file mode 100644 index 0000000..f0c4040 --- /dev/null +++ b/runs/trackformer_crowdhuman.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11115 --use_env src/train.py with deformable tracking crowdhuman clip_length=None output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/trackformer_mot17.sh b/runs/trackformer_mot17.sh new file mode 100644 index 0000000..6bd8548 --- /dev/null +++ b/runs/trackformer_mot17.sh @@ -0,0 +1,8 @@ +OUTPUT_DIR=$1 +RESUME_MODEL=$2 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 
--master_port=11115 --use_env src/train.py with deformable tracking mot17 clip_length=2 output_dir=$OUTPUT_DIR resume=$RESUME_MODEL >> $OUTPUT_DIR/log.txt diff --git a/runs/trackformer_ngpus.sh b/runs/trackformer_ngpus.sh new file mode 100644 index 0000000..8d8cd81 --- /dev/null +++ b/runs/trackformer_ngpus.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11115 --use_env src/train.py with deformable tracking mot17_cross_val clip_length=2 output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/vidhoi/vidhoi_detr+hoi.sh b/runs/vidhoi/vidhoi_detr+hoi.sh new file mode 100644 index 0000000..ed86c30 --- /dev/null +++ b/runs/vidhoi/vidhoi_detr+hoi.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable vidhoi hoi resume=models/vidhoi/ais_detr_b82a0c8/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/vidhoi/vidhoi_detr+tracking+hoi.sh b/runs/vidhoi/vidhoi_detr+tracking+hoi.sh new file mode 100644 index 0000000..d7d0883 --- /dev/null +++ b/runs/vidhoi/vidhoi_detr+tracking+hoi.sh @@ -0,0 +1,42 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +# DETR+Tracking+HOI +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env \ + src/train.py with deformable vidhoi vsgg hoi_detection=True freeze_detr=True \ + resume='models/vidhoi/pretrained/jd42_detr+tracking_2a5d6f7+clip_length=3/checkpoint_epoch3.pth' \ + clip_length=3 hoi_use_temporal_dynamics=True hoi_relation_propagation_on_inference=True \ + output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + +# # DETR+Tracking+HOI^{TDE} +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env \ +# src/train.py with deformable vidhoi vsgg \ +# resume='models/vidhoi/pretrained/jd42_detr+tracking_2a5d6f7+clip_length=3/checkpoint_epoch3.pth' \ +# hoi_detection=True \ +# freeze_detr=True \ +# hoi_use_temporal_dynamics=True \ +# clip_length=3 \ +# output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + +# # DETR+Tracking+HOI^{TDE+JFT} +# python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env \ +# src/train.py with deformable vidhoi vsgg \ +# resume='models/vidhoi/ais_detr+tracking+hoi_7adaa8f+hoi_use_temporal_dynamics/checkpoint_epoch7.pth' \ +# hoi_detection=True \ +# freeze_detr=False \ +# hoi_use_temporal_dynamics=True \ +# clip_length=3 \ +# output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt + + +# # eval +# python -u -m torch.distributed.launch --nproc_per_node=2 --master_port=11112 --use_env \ +# src/train.py with deformable vidhoi vsgg \ +# resume='models/vidhoi/ais_detr+tracking+hoi_7adaa8f+hoi_use_temporal_dynamics/jointly-tune/checkpoint_epoch4.pth' \ +# hoi_detection=True freeze_detr=True \ +# hoi_use_temporal_dynamics=True hoi_relation_propagation_on_inference=True \ +# eval_only=True \ +# output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/vidhoi/vidhoi_detr+tracking.sh b/runs/vidhoi/vidhoi_detr+tracking.sh new file mode 100644 index 0000000..eed30d5 --- /dev/null +++ b/runs/vidhoi/vidhoi_detr+tracking.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd 
/218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable vidhoi vsgg resume=models/vidhoi/ais_detr_b82a0c8/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/vidhoi/vidhoi_detr.sh b/runs/vidhoi/vidhoi_detr.sh new file mode 100644 index 0000000..971d10e --- /dev/null +++ b/runs/vidhoi/vidhoi_detr.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train.py with deformable vidhoi lr=0.00005 lr_backbone=0.00001 output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/runs/vidhoi/vidhoi_frcnn.sh b/runs/vidhoi/vidhoi_frcnn.sh new file mode 100644 index 0000000..8e05d1a --- /dev/null +++ b/runs/vidhoi/vidhoi_frcnn.sh @@ -0,0 +1,7 @@ +OUTPUT_DIR=$1 + +echo "OUTPUT_DIR=${OUTPUT_DIR}" +cd /218019030/projects/VideoSG-on-trackformer/ +mkdir -p $OUTPUT_DIR + +python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=11112 --use_env src/train_frcnn.py with frcnn vidhoi batch_size=2 output_dir=$OUTPUT_DIR >> $OUTPUT_DIR/log.txt diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f1447a9 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup(name='trackformer', + packages=['trackformer'], + package_dir={'':'src'}, + version='0.0.1', + install_requires=[],) diff --git a/src/STTran/__init__.py b/src/STTran/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/STTran/readme.md b/src/STTran/readme.md new file mode 100644 index 0000000..2a46f46 --- /dev/null +++ b/src/STTran/readme.md @@ -0,0 +1,3 @@ +Modified from `https://github.com/yrcong/STTran` + +Requires `pytorch<=1.5`! diff --git a/src/STTran/sttran.py b/src/STTran/sttran.py new file mode 100644 index 0000000..a0e9ce8 --- /dev/null +++ b/src/STTran/sttran.py @@ -0,0 +1,243 @@ +""" +Let's get the relationships yo +""" + +import numpy as np +import torch +import torch.nn as nn + +from .word_vectors import obj_edge_vectors +from .transformer import transformer +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +from torch.jit.annotations import Tuple, List +from collections import OrderedDict +from torchvision.models.detection.image_list import ImageList +from torchvision.ops.boxes import box_iou + +def normalize_box(boxes, image_size): + """ Convert point-form boxes to a (cx, cy, w, h) + representation, normalized by the image size, for comparison with center-size form ground truth data. + Args: + boxes: (tensor) point_form (xmin, ymin, xmax, ymax) boxes + Return: + boxes: (tensor) converted (cx, cy, w, h) boxes, normalized by image width and height.
+ """ + device = boxes.device + H, W = image_size + + wh = boxes[:, 2:] - boxes[:, :2] + 1.0 + xywh = torch.cat((boxes[:, :2] + 0.5 * wh, wh), 1) / torch.tensor([[W,H,W,H]]).to(device) + return xywh + +def multilabel_focal_loss(inputs, targets, gamma=2): + probs = inputs.sigmoid() + + # focal loss to balance positive/negative + pos_inds = targets.eq(1).float() + neg_inds = targets.lt(1).float() + pos_loss = torch.log(probs) * torch.pow(1 - probs, gamma) * pos_inds + neg_loss = torch.log(1 - probs) * torch.pow(probs, gamma) * neg_inds + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + # normalize + num_pos = pos_inds.float().sum() + if num_pos == 0: + loss = -neg_loss + else: + loss = -(pos_loss + neg_loss) / num_pos + + return loss + +class STTran(nn.Module): + + def __init__(self, args, frcnn, obj_classes, enc_layer_num=1, dec_layer_num=3): + + """ + :param classes: Object classes + :param rel_classes: Relationship classes. None if were not using rel mode + :param mode: (sgcls, predcls, or sgdet) + """ + super(STTran, self).__init__() + self.args = args + self.frcnn = frcnn + for p in self.parameters(): + p.requires_grad_(False) + + self.obj_classes = obj_classes + assert args.sgg_mode in ('sgdet', 'sgcls', 'predcls') + self.mode = args.sgg_mode + + ################################### + vis_dim, spatial_dim, hidden_dim, semantic_dim = 1024, 128, self.args.hidden_dim, 200 + self.pos_embed = nn.Sequential(nn.Linear(4, spatial_dim), + nn.ReLU(inplace=True), + nn.Dropout(0.1)) + + self.subj_fc = nn.Linear(vis_dim+spatial_dim, hidden_dim) + self.obj_fc = nn.Linear(vis_dim+spatial_dim, hidden_dim) + self.vr_fc = nn.Linear(256*7*7, hidden_dim) + + embed_vecs = obj_edge_vectors(obj_classes, wv_type='glove.6B', wv_dir='data/glove', wv_dim=semantic_dim) + self.obj_embed = nn.Embedding(len(obj_classes), semantic_dim) + self.obj_embed.weight.data = embed_vecs.clone() + + self.obj_embed2 = nn.Embedding(len(obj_classes), semantic_dim) + self.obj_embed2.weight.data = embed_vecs.clone() + + rel_dim = hidden_dim*3 + semantic_dim*2 + self.rel_input_fc = nn.Linear(rel_dim, hidden_dim) + self.glocal_transformer = transformer(enc_layer_num=enc_layer_num, dec_layer_num=dec_layer_num, embed_dim=hidden_dim, nhead=8, + dim_feedforward=2048, dropout=0.1, mode='latter') + + self.rel_compress = nn.Linear(hidden_dim, args.num_relations) + + @torch.no_grad() + def _get_detection_results(self, images, targets, IoU_threshold=0.5, K=16): + # from torchvision GeneralizedRCNN forward + self.frcnn.eval() + device = images.device + original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + image_list, _ = self.frcnn.transform(images, None) + features = self.frcnn.backbone(image_list.tensors) + + if self.mode == 'predcls': + detections = [{'boxes': t['boxes'], + 'labels': t['labels'], + 'scores': torch.ones(len(t['labels'])).to(device)} for t in targets] + org_h, org_w = original_image_sizes[0]; det_h, det_w = image_list.image_sizes[0] + ratios = torch.tensor([det_w/org_w, det_h/org_h, det_w/org_w, det_h/org_h]).unsqueeze(0).to(device) + boxes_for_roi_pool = [d['boxes'] * ratios for d in detections] + elif self.mode == 'sgdet': + proposals, _ = self.frcnn.rpn(image_list, features, None) + detections, _ = self.frcnn.roi_heads(features, proposals, image_list.image_sizes, None) + boxes_for_roi_pool = [d['boxes'] for d in detections] + detections = self.frcnn.transform.postprocess(detections, 
image_list.image_sizes, original_image_sizes) + + # box features + det_nums = [len(d['boxes']) for d in detections] + res = self.frcnn.roi_heads.box_roi_pool(features, boxes_for_roi_pool, image_list.image_sizes) + box_features = self.frcnn.roi_heads.box_head(res).split(det_nums, dim=0) + + # relation pairs + box_idxs, rel_pairs, rel_im_idxs, rel_union_boxes, rel_targets = [], [], [], [], [] + for image_id, (det_num, dets, tgt) in enumerate(zip(det_nums, detections, targets)): + box_idxs.append(torch.ones(det_num).to(device) * image_id) + + # rel pairs + rel_map = torch.zeros(det_num, det_num).to(device) + if self.training: + if self.mode == 'sgdet': + if len(tgt['labels']) < 1: # special case: no relations + img_rel_pairs = torch.zeros((0, 2), device=device).long() + rel_targets.append(torch.zeros((0, self.args.num_relations), device=device).float()) + else: + # detection matching with groundtruths + unmatched_IDX = -1 + label_match = (dets['labels'].unsqueeze(1) == tgt['labels'].unsqueeze(0)) + IoUs = box_iou(dets['boxes'], tgt['boxes']) + IoUs[~label_match] = 0 + + overlaps, det2gt_ids = IoUs.max(dim=1) + det2gt_ids[overlaps < IoU_threshold] = unmatched_IDX # unmatched detections -1 + + # sampling & relation targets + rel_map[(dets['labels']==1) & (det2gt_ids!=unmatched_IDX)] = 1 # subject is a detected human + rel_map.fill_diagonal_(0) + img_rel_pairs = rel_map.nonzero() + + rel_obj_matched_ids, perm = det2gt_ids[img_rel_pairs[:,1]].sort(descending=True) + img_rel_pairs = img_rel_pairs[perm[:K]] # sampling + + gt_obj_ids = det2gt_ids[img_rel_pairs] + rel_tgt_map = tgt['relation_map'][gt_obj_ids[:, 0], gt_obj_ids[:, 1]] # set annotated relations + rel_tgt_map[rel_obj_matched_ids[:K]==unmatched_IDX] = 0 # negative relations + rel_targets.append(rel_tgt_map) + elif self.mode == 'predcls': + rel_map[dets['labels']==1] = 1 # subject is human + rel_map.fill_diagonal_(0) + img_rel_pairs = rel_map.nonzero() + + # relation targets + rel_tgt_map = tgt['relation_map'][img_rel_pairs[:, 0], img_rel_pairs[:, 1]] + rel_targets.append(rel_tgt_map) + else: + rel_map[(dets['labels']==1) & (dets['scores']>0.2)] = 1 # subject is human + rel_map[:, dets['scores']<0.2] = 0 + rel_map.fill_diagonal_(0) + img_rel_pairs = rel_map.nonzero() + + rel_pairs.append(img_rel_pairs + sum(det_nums[:image_id])) + rel_im_idxs.append(torch.ones(len(img_rel_pairs)).to(device) * image_id) + + # union boxes + subj_boxes, obj_boxes = boxes_for_roi_pool[image_id][img_rel_pairs[:,0]], boxes_for_roi_pool[image_id][img_rel_pairs[:,1]] + union_boxes = torch.cat([torch.min(subj_boxes[:, :2], obj_boxes[:, :2]), torch.max(subj_boxes[:, 2:], obj_boxes[:, 2:])], dim=-1) + rel_union_boxes.append(union_boxes) + + res = { + 'box_nums': det_nums, + 'boxes': torch.cat([torch.cat([ids.unsqueeze(-1), d['boxes']], dim=-1) for ids, d in zip(box_idxs, detections)], dim=0), + 'scores': torch.cat([d['scores'] for d in detections], dim=0), + 'labels': torch.cat([d['labels'] for d in detections], dim=0), + 'box_features': torch.cat(box_features, dim=0), + 'rel_pair_nums': [len(p) for p in rel_pairs], + 'rel_pair_idxs': torch.cat(rel_pairs, dim=0), + 'rel_im_idxs': torch.cat(rel_im_idxs, dim=0), + 'rel_union_feats': self.frcnn.roi_heads.box_roi_pool(features, rel_union_boxes, image_list.image_sizes), + 'image_org_size': torch.tensor(original_image_sizes[0]).to(device) + } + if self.training: res.update({'rel_targets': torch.cat(rel_targets, dim=0)}) + return res + + def forward(self, images, targets=None): + entry = self._get_detection_results(images, 
targets) + rel_pairs = entry['rel_pair_idxs'] + + # visual part + pos_features = self.pos_embed(normalize_box(entry['boxes'][:, 1:], image_size=entry['image_org_size'])) + obj_features = torch.cat([entry['box_features'], pos_features], dim=-1) + subj_rep = self.subj_fc(obj_features[rel_pairs[:, 0]]) + obj_rep = self.obj_fc(obj_features[rel_pairs[:, 1]]) + vr = self.vr_fc(entry['rel_union_feats'].view(-1, 256*7*7)) + x_visual = torch.cat((subj_rep, obj_rep, vr), 1) + + # semantic part + subj_emb = self.obj_embed(entry['labels'][rel_pairs[:, 0]]) + obj_emb = self.obj_embed2(entry['labels'][rel_pairs[:, 1]]) + x_semantic = torch.cat((subj_emb, obj_emb), 1) + + rel_features = self.rel_input_fc(torch.cat((x_visual, x_semantic), dim=1)) + if len(entry['rel_im_idxs']) > 0: # Spatial-Temporal Transformer + rel_features, _, _ = \ + self.glocal_transformer(features=rel_features, im_idx=self.get_continuous_image_idxs(entry['rel_im_idxs'].clone())) + entry["rel_logits"] = self.rel_compress(rel_features) + + if self.training: + relation_cls_loss = multilabel_focal_loss(entry["rel_logits"], entry["rel_targets"]) + return {'relation_cls_loss': relation_cls_loss} + else: + return entry + + def get_continuous_image_idxs(self, org_idxes): + sorted_unique_idxes = sorted(org_idxes.unique().tolist()) + if org_idxes[-1] != len(sorted_unique_idxes)-1: + for new_id, org_id in enumerate(sorted_unique_idxes): + org_idxes[org_idxes==org_id] = new_id + return org_idxes + +def build_frcnn(args): + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + in_features = model.roi_heads.box_predictor.cls_score.in_features + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, args.num_classes) + return model + +def build_sttran(args, obj_classes): + frcnn = build_frcnn(args) + sttran = STTran(args, frcnn, obj_classes) + return sttran diff --git a/src/STTran/transformer.py b/src/STTran/transformer.py new file mode 100644 index 0000000..ed2a42c --- /dev/null +++ b/src/STTran/transformer.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +import copy + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, embed_dim=1936, nhead=4, dim_feedforward=2048, dropout=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + + self.linear1 = nn.Linear(embed_dim, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, src, input_key_padding_mask): + # local attention + src2, local_attention_weights = self.self_attn(src, src, src, key_padding_mask=input_key_padding_mask) + + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(nn.functional.relu(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src, local_attention_weights + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, embed_dim=1936, nhead=4, dim_feedforward=2048, dropout=0.1): + super().__init__() + + self.multihead2 = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + + self.linear1 = nn.Linear(embed_dim, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, embed_dim) + + + self.norm3 = nn.LayerNorm(embed_dim) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + def forward(self, 
global_input, input_key_padding_mask, position_embed): + + tgt2, global_attention_weights = self.multihead2(query=global_input+position_embed, key=global_input+position_embed, + value=global_input, key_padding_mask=input_key_padding_mask) + tgt = global_input + self.dropout2(tgt2) + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(nn.functional.relu(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + + return tgt, global_attention_weights + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, input, input_key_padding_mask): + output = input + weights = torch.zeros([self.num_layers, output.shape[1], output.shape[0], output.shape[0]]).to(output.device) + + for i, layer in enumerate(self.layers): + output, local_attention_weights = layer(output, input_key_padding_mask) + weights[i] = local_attention_weights + if self.num_layers > 0: + return output, weights + else: + return output, None + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, embed_dim): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + + def forward(self, global_input, input_key_padding_mask, position_embed): + + output = global_input + weights = torch.zeros([self.num_layers, output.shape[1], output.shape[0], output.shape[0]]).to(output.device) + + for i, layer in enumerate(self.layers): + output, global_attention_weights = layer(output, input_key_padding_mask, position_embed) + weights[i] = global_attention_weights + + if self.num_layers>0: + return output, weights + else: + return output, None + + +class transformer(nn.Module): + ''' Spatial Temporal Transformer + local_attention: spatial encoder + global_attention: temporal decoder + position_embedding: frame encoding (window_size*dim) + mode: both--use the features from both frames in the window + latter--use the features from the latter frame in the window + ''' + def __init__(self, enc_layer_num=1, dec_layer_num=3, embed_dim=1936, nhead=8, dim_feedforward=2048, + dropout=0.1, mode=None): + super(transformer, self).__init__() + self.mode = mode + + encoder_layer = TransformerEncoderLayer(embed_dim=embed_dim, nhead=nhead, dim_feedforward=dim_feedforward, + dropout=dropout) + self.local_attention = TransformerEncoder(encoder_layer, enc_layer_num) + + decoder_layer = TransformerDecoderLayer(embed_dim=embed_dim, nhead=nhead, dim_feedforward=dim_feedforward, + dropout=dropout) + + self.global_attention = TransformerDecoder(decoder_layer, dec_layer_num, embed_dim) + + self.position_embedding = nn.Embedding(2, embed_dim) #present and next frame + nn.init.uniform_(self.position_embedding.weight) + + + def forward(self, features, im_idx): + rel_idx = torch.arange(im_idx.shape[0]) + + l = torch.sum(im_idx == torch.mode(im_idx)[0]) # the highest box number in the single frame + b = int(im_idx[-1] + 1) + rel_input = torch.zeros([l, b, features.shape[1]]).to(features.device) + masks = torch.zeros([b, l], dtype=torch.uint8).to(features.device) + # TODO Padding/Mask maybe don't need for-loop + for i in range(b): + rel_input[:torch.sum(im_idx == i), i, :] = features[im_idx == i] + masks[i, torch.sum(im_idx == i):] = 1 + + # spatial encoder + local_output, local_attention_weights = self.local_attention(rel_input, masks) + local_output = (local_output.permute(1, 0, 2)).contiguous().view(-1, 
features.shape[1])[masks.view(-1) == 0] + + global_input = torch.zeros([l * 2, b - 1, features.shape[1]]).to(features.device) + position_embed = torch.zeros([l * 2, b - 1, features.shape[1]]).to(features.device) + idx = -torch.ones([l * 2, b - 1]).to(features.device) + idx_plus = -torch.ones([l * 2, b - 1], dtype=torch.long).to(features.device) #TODO + + # sliding window size = 2 + for j in range(b - 1): + global_input[:torch.sum((im_idx == j) + (im_idx == j + 1)), j, :] = local_output[(im_idx == j) + (im_idx == j + 1)] # concatenate the objects of two adjacent frames, then run self-attention over them + idx[:torch.sum((im_idx == j) + (im_idx == j + 1)), j] = im_idx[(im_idx == j) + (im_idx == j + 1)] + idx_plus[:torch.sum((im_idx == j) + (im_idx == j + 1)), j] = rel_idx[(im_idx == j) + (im_idx == j + 1)] #TODO + + position_embed[:torch.sum(im_idx == j), j, :] = self.position_embedding.weight[0] + position_embed[torch.sum(im_idx == j):torch.sum(im_idx == j)+torch.sum(im_idx == j+1), j, :] = self.position_embedding.weight[1] + + global_masks = (torch.sum(global_input.view(-1, features.shape[1]),dim=1) == 0).view(l * 2, b - 1).permute(1, 0) + # temporal decoder + global_output, global_attention_weights = self.global_attention(global_input, global_masks, position_embed) + + output = torch.zeros_like(features) + + if self.mode == 'both': + # both + for j in range(b - 1): + if j == 0: + output[im_idx == j] = global_output[:, j][idx[:, j] == j] + + if j == b - 2: + output[im_idx == j+1] = global_output[:, j][idx[:, j] == j+1] + else: + output[im_idx == j + 1] = (global_output[:, j][idx[:, j] == j + 1] + + global_output[:, j + 1][idx[:, j + 1] == j + 1]) / 2 + + elif self.mode == 'latter': + # latter + for j in range(b - 1): + if j == 0: + output[im_idx == j] = global_output[:, j][idx[:, j] == j] + + output[im_idx == j + 1] = global_output[:, j][idx[:, j] == j + 1] + + return output, global_attention_weights, local_attention_weights + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + diff --git a/src/STTran/word_vectors.py b/src/STTran/word_vectors.py new file mode 100644 index 0000000..4aa779d --- /dev/null +++ b/src/STTran/word_vectors.py @@ -0,0 +1,130 @@ +""" +Adapted from PyTorch's text library.
+""" + +import array +import os +import zipfile + +import six +import torch +from six.moves.urllib.request import urlretrieve +from tqdm import tqdm +import sys + +def obj_edge_vectors(names, wv_type='glove.6B', wv_dir=None, wv_dim=300): + wv_dict, wv_arr, wv_size = load_word_vectors(wv_dir, wv_type, wv_dim) + + vectors = torch.Tensor(len(names), wv_dim) + vectors.normal_(0,1) + + for i, token in enumerate(names): + wv_index = wv_dict.get(token.split('/')[0], None) + if wv_index is not None: + vectors[i] = wv_arr[wv_index] + else: + # Try the longest word (hopefully won't be a preposition + lw_token = sorted(token.split(' '), key=lambda x: len(x), reverse=True)[0] + print("{} -> {} ".format(token, lw_token)) + wv_index = wv_dict.get(lw_token, None) + if wv_index is not None: + vectors[i] = wv_arr[wv_index] + else: + print("fail on {}".format(token)) + + return vectors + +URL = { + 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', + 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', + 'glove.twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', + 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', + } + + +def load_word_vectors(root, wv_type, dim): + """Load word vectors from a path, trying .pt, .txt, and .zip extensions.""" + if isinstance(dim, int): + dim = str(dim) + 'd' + fname = os.path.join(root, wv_type + '.' + dim) + if os.path.isfile(fname + '.pt'): + fname_pt = fname + '.pt' + print('loading word vectors from', fname_pt) + try: + return torch.load(fname_pt) + except Exception as e: + print(""" + Error loading the model from {} + + This could be because this code was previously run with one + PyTorch version to generate cached data and is now being + run with another version. + You can try to delete the cached files on disk (this file + and others) and re-running the code + + Error message: + --------- + {} + """.format(fname_pt, str(e))) + sys.exit(-1) + if os.path.isfile(fname + '.txt'): + fname_txt = fname + '.txt' + cm = open(fname_txt, 'rb') + cm = [line for line in cm] + elif os.path.basename(wv_type) in URL: + url = URL[wv_type] + print('downloading word vectors from {}'.format(url)) + filename = os.path.basename(fname) + if not os.path.exists(root): + os.makedirs(root) + with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: + fname, _ = urlretrieve(url, fname, reporthook=reporthook(t)) + with zipfile.ZipFile(fname, "r") as zf: + print('extracting word vectors into {}'.format(root)) + zf.extractall(root) + if not os.path.isfile(fname + '.txt'): + raise RuntimeError('no word vectors of requested dimension found') + return load_word_vectors(root, wv_type, dim) + else: + raise RuntimeError('unable to load word vectors') + + wv_tokens, wv_arr, wv_size = [], array.array('d'), None + if cm is not None: + for line in tqdm(range(len(cm)), desc="loading word vectors from {}".format(fname_txt)): + entries = cm[line].strip().split(b' ') + word, entries = entries[0], entries[1:] + if wv_size is None: + wv_size = len(entries) + try: + if isinstance(word, six.binary_type): + word = word.decode('utf-8') + except: + print('non-UTF8 token', repr(word), 'ignored') + continue + wv_arr.extend(float(x) for x in entries) + wv_tokens.append(word) + + wv_dict = {word: i for i, word in enumerate(wv_tokens)} + wv_arr = torch.Tensor(wv_arr).view(-1, wv_size) + ret = (wv_dict, wv_arr, wv_size) + torch.save(ret, fname + '.pt') + return ret + +def reporthook(t): + """https://github.com/tqdm/tqdm""" + last_b = [0] + + def inner(b=1, bsize=1, 
tsize=None): + """ + b: int, optional + Number of blocks just transferred [default: 1]. + bsize: int, optional + Size of each block (in tqdm units) [default: 1]. + tsize: int, optional + Total size (in tqdm units). If [default: None] remains unchanged. + """ + if tsize is not None: + t.total = tsize + t.update((b - last_b[0]) * bsize) + last_b[0] = b + return inner diff --git a/src/base_trackers/README.md b/src/base_trackers/README.md new file mode 100644 index 0000000..f95a68e --- /dev/null +++ b/src/base_trackers/README.md @@ -0,0 +1,3 @@ +References: +* https://github.com/adipandas/multi-object-tracker +* https://github1s.com/ifzhang/ByteTrack \ No newline at end of file diff --git a/src/base_trackers/__init__.py b/src/base_trackers/__init__.py new file mode 100644 index 0000000..caa36d5 --- /dev/null +++ b/src/base_trackers/__init__.py @@ -0,0 +1,4 @@ +from .tracker import Tracker as CentroidTracker +from .sort_tracker import SORT +from .iou_tracker import IOUTracker +from .byte_tracker import BYTETracker diff --git a/src/base_trackers/byte_tracker.py b/src/base_trackers/byte_tracker.py new file mode 100644 index 0000000..490ae92 --- /dev/null +++ b/src/base_trackers/byte_tracker.py @@ -0,0 +1,322 @@ +import numpy as np +from .bytetrack.kalman_filter import KalmanFilter +from .bytetrack import matching +from .bytetrack.basetrack import BaseTrack, TrackState +from .utils import box_xywh_to_xyxy + +class STrack(BaseTrack): + shared_kalman = KalmanFilter() + def __init__(self, tlwh, score, class_id): + + # wait activate + self._tlwh = np.asarray(tlwh, dtype=np.float) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + + # states + self.class_id = class_id + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + # self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) + ) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + self.score = new_track.score + + def update(self, new_track, frame_id): + """ + Update a matched track + :type new_track: STrack + :type frame_id: int + :type update_feature: bool + :return: + """ + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh =
new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + + @property + # @jit(nopython=True) + def tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + # @jit(nopython=True) + def tlbr(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + # @jit(nopython=True) + def tlwh_to_xyah(tlwh): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. + """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + def to_xyah(self): + return self.tlwh_to_xyah(self.tlwh) + + @staticmethod + # @jit(nopython=True) + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + # @jit(nopython=True) + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) + + def output(self): + """ + Track data output in VISDRONE Challenge format with tuple as + `(frame_index, target_id, bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, + truncation, occlusion)`. + """ + bbox = self.tlwh + mot_tuple = ( + self.frame_id, self.track_id, bbox[0], bbox[1], bbox[2], bbox[3], + self.score, self.class_id, -1, -1 + ) + return mot_tuple + +class BYTETracker(object): + def __init__(self, max_lost=2, confidence_threshold=0.7, iou_threshold=0.3): + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + + self.frame_id = 0 + self.max_time_lost = max_lost + self.kalman_filter = KalmanFilter() + + # hyper parameters + self.det_thresh = confidence_threshold + self.track_thresh = self.det_thresh - 0.1 + self.iou_threshold = iou_threshold + + def update(self, bboxes, scores, class_ids): + self.frame_id += 1 + bboxes = box_xywh_to_xyxy(bboxes) + + activated_starcks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + remain_inds = scores > self.track_thresh + inds_low = scores > 0.1 + inds_high = scores < self.track_thresh + + inds_second = np.logical_and(inds_low, inds_high) + dets_second = bboxes[inds_second] + dets = bboxes[remain_inds] + scores_keep = scores[remain_inds] + scores_second = scores[inds_second] + + if len(dets) > 0: + '''Detections''' + detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s, class_id=c) for (tlbr, s, c) in zip(dets, scores_keep, class_ids[remain_inds])] + else: + detections = [] + + ''' Add newly detected tracklets to tracked_stracks''' + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + unconfirmed.append(track) + else: + tracked_stracks.append(track) + + ''' Step 2: First association, with high score detection boxes''' + strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + STrack.multi_predict(strack_pool) + dists = 
matching.iou_distance(strack_pool, detections) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=1-self.iou_threshold) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + ''' Step 3: Second association, with low score detection boxes''' + # association the untrack to the low score detections + if len(dets_second) > 0: + detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for (tlbr, s, c) in zip(dets_second, scores_second, class_ids[inds_second])] + else: + detections_second = [] + r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] + dists = matching.iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + for it in u_track: + track = r_tracked_stracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + + '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' + detections = [detections[i] for i in u_detection] + dists = matching.iou_distance(unconfirmed, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_starcks.append(unconfirmed[itracked]) + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + + """ Step 4: Init new stracks""" + for inew in u_detection: + track = detections[inew] + if track.score < self.det_thresh: + continue + track.activate(self.kalman_filter, self.frame_id) + activated_starcks.append(track) + """ Step 5: Update state""" + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] + self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) + self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) + self.removed_stracks.extend(removed_stracks) + self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + + # return + output_stracks = [track.output() for track in self.tracked_stracks if track.is_activated] + return output_stracks + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + 
del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if not i in dupa] + resb = [t for i, t in enumerate(stracksb) if not i in dupb] + return resa, resb diff --git a/src/base_trackers/bytetrack/__init__.py b/src/base_trackers/bytetrack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/base_trackers/bytetrack/basetrack.py b/src/base_trackers/bytetrack/basetrack.py new file mode 100644 index 0000000..a7130b5 --- /dev/null +++ b/src/base_trackers/bytetrack/basetrack.py @@ -0,0 +1,52 @@ +import numpy as np +from collections import OrderedDict + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + score = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + def mark_removed(self): + self.state = TrackState.Removed \ No newline at end of file diff --git a/src/base_trackers/bytetrack/kalman_filter.py b/src/base_trackers/bytetrack/kalman_filter.py new file mode 100644 index 0000000..deda8a2 --- /dev/null +++ b/src/base_trackers/bytetrack/kalman_filter.py @@ -0,0 +1,270 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import scipy.linalg + + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, a, h, vx, vy, va, vh + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """Create track from unassociated measurement. 
+ + Parameters + ---------- + measurement : ndarray + Bounding box coordinates (x, y, a, h) with center position (x, y), + aspect ratio a, and height h. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are initialized + to 0 mean. + + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3]] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """Run Kalman filter prediction step. + + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + + """ + std_pos = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3]] + std_vel = [ + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3]] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + #mean = np.dot(self._motion_mat, mean) + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """Project state distribution to measurement space. + + Parameters + ---------- + mean : ndarray + The state's mean vector (8 dimensional array). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. + + """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3]] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance): + """Run Kalman filter prediction step (Vectorized version). + Parameters + ---------- + mean : ndarray + The Nx8 dimensional mean matrix of the object states at the previous + time step. + covariance : ndarray + The Nx8x8 dimensional covariance matrics of the object states at the + previous time step. + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. 
+ """ + std_pos = [ + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3]] + std_vel = [ + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3]] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement): + """Run Kalman filter correction step. + + Parameters + ---------- + mean : ndarray + The predicted state's mean vector (8 dimensional). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + measurement : ndarray + The 4 dimensional measurement vector (x, y, a, h), where (x, y) + is the center position, a the aspect ratio, and h the height of the + bounding box. + + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. + + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, + only_position=False, metric='maha'): + """Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + Parameters + ---------- + mean : ndarray + Mean vector over the state distribution (8 dimensional). + covariance : ndarray + Covariance of the state distribution (8x8 dimensional). + measurements : ndarray + An Nx4 dimensional matrix of N measurements, each in + format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position : Optional[bool] + If True, distance computation is done with respect to the bounding + box center position only. + Returns + ------- + ndarray + Returns an array of length N, where the i-th element contains the + squared Mahalanobis distance between (mean, covariance) and + `measurements[i]`. 
+ """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular( + cholesky_factor, d.T, lower=True, check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') \ No newline at end of file diff --git a/src/base_trackers/bytetrack/matching.py b/src/base_trackers/bytetrack/matching.py new file mode 100644 index 0000000..45d8fbd --- /dev/null +++ b/src/base_trackers/bytetrack/matching.py @@ -0,0 +1,186 @@ +import cv2 +import numpy as np +import scipy +import lap +from scipy.spatial.distance import cdist + +from cython_bbox import bbox_overlaps as bbox_ious +from . import kalman_filter + +def merge_matches(m1, m2, shape): + O,P,Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1*M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def _indices_to_matches(cost_matrix, indices, thresh): + matched_cost = cost_matrix[tuple(zip(*indices))] + matched_mask = (matched_cost <= thresh) + + matches = indices[matched_mask] + unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) + + return matches, unmatched_a, unmatched_b + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def ious(atlbrs, btlbrs): + """ + Compute cost based on IoU + :type atlbrs: list[tlbr] | np.ndarray + :type atlbrs: list[tlbr] | np.ndarray + + :rtype ious np.ndarray + """ + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float) + if ious.size == 0: + return ious + + ious = bbox_ious( + np.ascontiguousarray(atlbrs, dtype=np.float), + np.ascontiguousarray(btlbrs, dtype=np.float) + ) + + return ious + + +def iou_distance(atracks, btracks, consider_track_label=True): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlbr for track in atracks] + btlbrs = [track.tlbr for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + if consider_track_label: + a_labels = np.array([t.class_id for t in atracks]) + b_labels = np.array([t.class_id for t in btracks]) + label_match = (a_labels[:, None] == b_labels[None, :]) + 
cost_matrix = np.where(label_match, cost_matrix, 1) + + return cost_matrix + +def v_iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] + btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + +def embedding_distance(tracks, detections, metric='cosine'): + """ + :param tracks: list[STrack] + :param detections: list[BaseTrack] + :param metric: + :return: cost_matrix np.ndarray + """ + + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float) + #for i, track in enumerate(tracks): + #cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) + track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float) + cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features + return cost_matrix + + +def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = np.inf + return cost_matrix + + +def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position, metric='maha') + cost_matrix[row, gating_distance > gating_threshold] = np.inf + cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance + return cost_matrix + + +def fuse_iou(cost_matrix, tracks, detections): + if cost_matrix.size == 0: + return cost_matrix + reid_sim = 1 - cost_matrix + iou_dist = iou_distance(tracks, detections) + iou_sim = 1 - iou_dist + fuse_sim = reid_sim * (1 + iou_sim) / 2 + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + #fuse_sim = fuse_sim * (1 + det_scores) / 2 + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def fuse_score(cost_matrix, detections): + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + fuse_cost = 1 - fuse_sim + return fuse_cost \ No newline at end of file diff --git a/src/base_trackers/iou_tracker.py b/src/base_trackers/iou_tracker.py new file mode 100644 index 0000000..4018613 --- /dev/null +++ 
b/src/base_trackers/iou_tracker.py @@ -0,0 +1,63 @@ +from .utils import iou_xywh +from .tracker import Tracker + + +class IOUTracker(Tracker): + """ + Intersection over Union Tracker. + + References + ---------- + * Implementation of this algorithm is heavily based on https://github.com/bochinski/iou-tracker + + Args: + max_lost (int): Maximum number of consecutive frames object was not detected. + tracker_output_format (str): Output format of the tracker. + min_detection_confidence (float): Threshold for minimum detection confidence. + max_detection_confidence (float): Threshold for max. detection confidence. + iou_threshold (float): Intersection over union minimum value. + """ + + def __init__( + self, + max_lost=2, + iou_threshold=0.5, + tracker_output_format='visdrone_challenge' + ): + self.iou_threshold = iou_threshold + + super(IOUTracker, self).__init__(max_lost=max_lost, tracker_output_format=tracker_output_format) + + def update(self, bboxes, detection_scores, class_ids): + detections = Tracker.preprocess_input(bboxes, class_ids, detection_scores) + self.frame_count += 1 + track_ids = list(self.tracks.keys()) + + updated_tracks = [] + for track_id in track_ids: + if len(detections) > 0: + idx, best_match = max(enumerate(detections), key=lambda x: self.iou_and_class_match(self.tracks[track_id], x[1])) + (bb, cid, scr) = best_match + + if self.iou_and_class_match(self.tracks[track_id], best_match) > self.iou_threshold: + self._update_track(track_id, self.frame_count, bb, scr, class_id=cid, + iou_score=self.iou_and_class_match(self.tracks[track_id], best_match)) + updated_tracks.append(track_id) + del detections[idx] + + if len(updated_tracks) == 0 or track_id is not updated_tracks[-1]: + self.tracks[track_id].lost += 1 + if self.tracks[track_id].lost > self.max_lost: + self._remove_track(track_id) + + for bb, cid, scr in detections: + self._add_track(self.frame_count, bb, scr, class_id=cid) + + outputs = self._get_tracks(self.tracks) + return outputs + + def iou_and_class_match(self, track, det): + if track.class_id == det[1]: + return iou_xywh(track.bbox, det[0]) + else: + return 0 diff --git a/src/base_trackers/sort_tracker.py b/src/base_trackers/sort_tracker.py new file mode 100644 index 0000000..7414b81 --- /dev/null +++ b/src/base_trackers/sort_tracker.py @@ -0,0 +1,269 @@ +""" + SORT: A Simple, Online and Realtime Tracker + Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" +from __future__ import print_function + +import numpy as np +from filterpy.kalman import KalmanFilter +from .utils import box_xywh_to_xyxy +np.random.seed(0) + +def linear_assignment(cost_matrix): + try: + import lap + _, x, y = lap.lapjv(cost_matrix, extend_cost=True) + return np.array([[y[i], i] for i in x if i >= 0]) # + except ImportError: + from scipy.optimize import linear_sum_assignment + x, y = linear_sum_assignment(cost_matrix) + return np.array(list(zip(x, y))) + + +def iou_batch(bb_test, bb_gt): + """ + From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] + """ + bb_gt = np.expand_dims(bb_gt, 0) + bb_test = np.expand_dims(bb_test, 1) + + xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0]) + yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1]) + xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2]) + yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1]) + + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh) + return (o) + + +def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2. + y = bbox[1] + h / 2. + s = w * h # scale is just area + r = w / float(h) + return np.array([x, y, s, r]).reshape((4, 1)) + + +def convert_x_to_bbox(x, score=None): + """ + Takes a bounding box in the centre form [x,y,s,r] and returns it in the form + [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right + """ + w = np.sqrt(x[2] * x[3]) + h = x[2] / w + if (score == None): + return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4)) + else: + return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5)) + +class KalmanBoxTrack(object): + """ + This class represents the internal state of individual tracked objects observed as bbox. + """ + count = 0 + + def __init__(self, frame_id, bbox, detection_confidence, class_id=None): + """ + Initialises a tracker using initial bounding box. + """ + # define constant velocity model + self.kf = KalmanFilter(dim_x=7, dim_z=4) + self.kf.F = np.array( + [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]]) + self.kf.H = np.array( + [[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]]) + + self.kf.R[2:, 2:] *= 10. + self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10. + self.kf.Q[-1, -1] *= 0.01 + self.kf.Q[4:, 4:] *= 0.01 + + self.kf.x[:4] = convert_bbox_to_z(bbox) + self.time_since_update = 0 + KalmanBoxTrack.count += 1 + self.id = KalmanBoxTrack.count + self.history = [] + + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + self.frame_id = frame_id + self.detection_confidence = detection_confidence + self.class_id = class_id + + + def update(self, frame_id, bbox, detection_confidence, class_id=None): + """ + Updates the state vector with observed bbox. 
+ """ + self.frame_id = frame_id + self.detection_confidence = detection_confidence + self.class_id = class_id + + self.time_since_update = 0 + self.history = [] + self.hits += 1 + self.hit_streak += 1 + self.kf.update(convert_bbox_to_z(bbox)) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + if ((self.kf.x[6] + self.kf.x[2]) <= 0): + self.kf.x[6] *= 0.0 + self.kf.predict() + self.age += 1 + if (self.time_since_update > 0): + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(convert_x_to_bbox(self.kf.x)) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return convert_x_to_bbox(self.kf.x) + + def output(self): + """ + Track data output in VISDRONE Challenge format with tuple as + `(frame_index, target_id, bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, + truncation, occlusion)`. + """ + bbox = self.get_state()[0] + mot_tuple = ( + self.frame_id, self.id, bbox[0], bbox[1], bbox[2]-bbox[0], bbox[3]-bbox[1], + self.detection_confidence, self.class_id, -1, -1 + ) + return mot_tuple + + +def associate_detections_to_tracks(detections, det_class_ids, tracks, iou_threshold=0.3): + """ + Assigns detections to tracked object (both represented as bounding boxes) + Returns 3 lists of matches, unmatched_detections and unmatched_trackers + """ + if (len(tracks) == 0): + return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 4), dtype=int) + + iou_matrix = iou_batch(detections, tracks) + iou_matrix = np.where(det_class_ids[:, None] == tracks[:, -1][None, :], iou_matrix, 0) # set label unmatched case as 0 + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(-iou_matrix) + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if (d not in matched_indices[:, 0]): + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(tracks): + if (t not in matched_indices[:, 1]): + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if (iou_matrix[m[0], m[1]] < iou_threshold): + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if (len(matches) == 0): + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +class SORT(object): + def __init__(self, max_lost=2, min_hits=1, iou_threshold=0.3): + """ + Sets key parameters for SORT + """ + self.max_lost = max_lost + self.min_hits = min_hits + self.iou_threshold = iou_threshold + self.tracks = [] + self.frame_count = 0 + + def update(self, bboxes, detection_scores, class_ids): + """ + Params: + bboxes of xyxy + Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections). + Returns the a similar array, where the last column is the object ID. + NOTE: The number of objects returned may differ from the number of detections provided. + """ + self.frame_count += 1 + bboxes = box_xywh_to_xyxy(bboxes) + + # get predicted locations from existing trackers. 
+ trks = np.zeros((len(self.tracks), 5)) + to_del = [] + for t, trk in enumerate(trks): + pos = self.tracks[t].predict()[0] + trk[:] = [pos[0], pos[1], pos[2], pos[3], self.tracks[t].class_id] + if np.any(np.isnan(pos)): + to_del.append(t) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + for t in reversed(to_del): + self.tracks.pop(t) + + matched, unmatched_dets, _ = associate_detections_to_tracks(bboxes, class_ids, trks, self.iou_threshold) + # update matched trackers with assigned detections + for m in matched: + self.tracks[m[1]].update(self.frame_count, bboxes[m[0]], detection_scores[m[0]], class_id=class_ids[m[0]]) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTrack(self.frame_count, bboxes[i], detection_scores[i], class_id=class_ids[i]) + self.tracks.append(trk) + + # check unmatched tracks + i = len(self.tracks) + for trk in reversed(self.tracks): + i -= 1 + if (trk.time_since_update > self.max_lost): + self.tracks.pop(i) + + outputs = self._get_tracks(self.tracks) + return outputs + + def _get_tracks(self, tracks): + outputs = [] + for track in tracks: + if (track.time_since_update < 1) and (track.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): + outputs.append(track.output()) + return outputs \ No newline at end of file diff --git a/src/base_trackers/track.py b/src/base_trackers/track.py new file mode 100644 index 0000000..7f685b0 --- /dev/null +++ b/src/base_trackers/track.py @@ -0,0 +1,147 @@ +import numpy as np + +class Track: + """ + Track containing attributes to track various objects. + + Args: + frame_id (int): Camera frame id. + track_id (int): Track Id + bbox (numpy.ndarray): Bounding box pixel coordinates as (xmin, ymin, width, height) of the track. + detection_confidence (float): Detection confidence of the object (probability). + class_id (str or int): Class label id. + lost (int): Number of times the object or track was not tracked by tracker in consecutive frames. + iou_score (float): Intersection over union score. + data_output_format (str): Output format for data in tracker. + Options include ``['mot_challenge', 'visdrone_challenge']``. Default is ``mot_challenge``. + kwargs (dict): Additional key word arguments. + + """ + + count = 0 + + metadata = dict( + data_output_formats=['mot_challenge', 'visdrone_challenge'] + ) + + def __init__( + self, + track_id, + frame_id, + bbox, + detection_confidence, + class_id=None, + lost=0, + iou_score=0., + data_output_format='visdrone_challenge', + **kwargs + ): + assert data_output_format in Track.metadata['data_output_formats'] + Track.count += 1 + self.id = track_id + + self.detection_confidence_max = 0. + self.lost = 0 + self.age = 0 + + self.update(frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs) + + if data_output_format == 'mot_challenge': + self.output = self.get_mot_challenge_format + elif data_output_format == 'visdrone_challenge': + self.output = self.get_vis_drone_format + else: + raise NotImplementedError + + def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0., **kwargs): + """ + Update the track. + + Args: + frame_id (int): Camera frame id. + bbox (numpy.ndarray): Bounding box pixel coordinates as (xmin, ymin, width, height) of the track. + detection_confidence (float): Detection confidence of the object (probability). + class_id (int or str): Class label id. 
+ lost (int): Number of times the object or track was not tracked by tracker in consecutive frames. + iou_score (float): Intersection over union score. + kwargs (dict): Additional key word arguments. + """ + self.class_id = class_id + self.bbox = np.array(bbox) + self.detection_confidence = detection_confidence + self.frame_id = frame_id + self.iou_score = iou_score + + if lost == 0: + self.lost = 0 + else: + self.lost += lost + + for k, v in kwargs.items(): + setattr(self, k, v) + + self.detection_confidence_max = max(self.detection_confidence_max, detection_confidence) + + self.age += 1 + + @property + def centroid(self): + """ + Return the centroid of the bounding box. + + Returns: + numpy.ndarray: Centroid (x, y) of bounding box. + + """ + return np.array((self.bbox[0]+0.5*self.bbox[2], self.bbox[1]+0.5*self.bbox[3])) + + def get_mot_challenge_format(self): + """ + Get the tracker data in MOT challenge format as a tuple of elements containing + `(frame, id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z)` + + References: + - Website : https://motchallenge.net/ + + Returns: + tuple: Tuple of 10 elements representing `(frame, id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z)`. + + """ + mot_tuple = ( + self.frame_id, self.id, self.bbox[0], self.bbox[1], self.bbox[2], self.bbox[3], self.detection_confidence, + -1, -1, -1 + ) + return mot_tuple + + def get_vis_drone_format(self): + """ + Track data output in VISDRONE Challenge format with tuple as + `(frame_index, target_id, bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, + truncation, occlusion)`. + + References: + - Website : http://aiskyeye.com/ + - Paper : https://arxiv.org/abs/2001.06303 + - GitHub : https://github.com/VisDrone/VisDrone2018-MOT-toolkit + - GitHub : https://github.com/VisDrone/ + + Returns: + tuple: Tuple containing the elements as `(frame_index, target_id, bbox_left, bbox_top, bbox_width, bbox_height, + score, object_category, truncation, occlusion)`. + """ + mot_tuple = ( + self.frame_id, self.id, self.bbox[0], self.bbox[1], self.bbox[2], self.bbox[3], + self.detection_confidence, self.class_id, -1, -1 + ) + return mot_tuple + + def predict(self): + """ + Implement to prediction the next estimate of track. + """ + raise NotImplemented + + @staticmethod + def print_all_track_output_formats(): + print(Track.metadata['data_output_formats']) + diff --git a/src/base_trackers/tracker.py b/src/base_trackers/tracker.py new file mode 100644 index 0000000..62020da --- /dev/null +++ b/src/base_trackers/tracker.py @@ -0,0 +1,176 @@ +from collections import OrderedDict +import numpy as np +from scipy.spatial import distance +from .utils import get_centroid +from .track import Track + + +class Tracker: + """ + Greedy Tracker with tracking based on ``centroid`` location of the bounding box of the object. + This tracker is also referred as ``CentroidTracker`` in this repository. + + Args: + max_lost (int): Maximum number of consecutive frames object was not detected. + tracker_output_format (str): Output format of the tracker. + """ + + def __init__(self, max_lost=2, tracker_output_format='visdrone_challenge'): + self.next_track_id = 1 + self.tracks = OrderedDict() + self.max_lost = max_lost + self.frame_count = 0 + self.tracker_output_format = tracker_output_format + + def _add_track(self, frame_id, bbox, detection_confidence, class_id, **kwargs): + """ + Add a newly detected object to the queue. + + Args: + frame_id (int): Camera frame id. 
+ bbox (numpy.ndarray): Bounding box pixel coordinates as (xmin, ymin, xmax, ymax) of the track. + detection_confidence (float): Detection confidence of the object (probability). + class_id (str or int): Class label id. + kwargs (dict): Additional key word arguments. + """ + + self.tracks[self.next_track_id] = Track( + self.next_track_id, frame_id, bbox, detection_confidence, class_id=class_id, + data_output_format=self.tracker_output_format, + **kwargs + ) + self.next_track_id += 1 + + def _remove_track(self, track_id): + """ + Remove tracker data after object is lost. + + Args: + track_id (int): track_id of the track lost while tracking. + """ + + del self.tracks[track_id] + + def _update_track(self, track_id, frame_id, bbox, detection_confidence, class_id, lost=0, iou_score=0., **kwargs): + """ + Update track state. + + Args: + track_id (int): ID of the track. + frame_id (int): Frame count. + bbox (numpy.ndarray or list): Bounding box coordinates as `(xmin, ymin, width, height)`. + detection_confidence (float): Detection confidence (a.k.a. detection probability). + class_id (int): ID of the class (aka label) of the object being tracked. + lost (int): Number of frames the object was lost while tracking. + iou_score (float): Intersection over union. + kwargs (dict): Additional keyword arguments. + """ + + self.tracks[track_id].update( + frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs + ) + + @staticmethod + def _get_tracks(tracks): + """ + Output the information of tracks. + + Args: + tracks (OrderedDict): Tracks dictionary with (key, value) as (track_id, corresponding `Track` objects). + + Returns: + list: List of tracks being currently tracked by the tracker. + """ + + outputs = [] + for trackid, track in tracks.items(): + if not track.lost: + outputs.append(track.output()) + return outputs + + @staticmethod + def preprocess_input(bboxes, class_ids, detection_scores): + """ + Preprocess the input data. + + Args: + bboxes (list or numpy.ndarray): Array of bounding boxes with each bbox as a tuple containing `(xmin, ymin, width, height)`. + class_ids (list or numpy.ndarray): Array of Class ID or label ID. + detection_scores (list or numpy.ndarray): Array of detection scores (a.k.a. detection probabilities). + + Returns: + detections (list[Tuple]): Data for detections as list of tuples containing `(bbox, class_id, detection_score)`. + """ + + new_bboxes = np.array(bboxes, dtype='float') + new_class_ids = np.array(class_ids, dtype='int') + new_detection_scores = np.array(detection_scores) + + new_detections = list(zip(new_bboxes, new_class_ids, new_detection_scores)) + return new_detections + + def update(self, bboxes, detection_scores, class_ids): + """ + Update the tracker based on the new bounding boxes. + + Args: + bboxes (numpy.ndarray or list): List of bounding boxes detected in the current frame. Each element of the list represent + coordinates of bounding box as tuple `(top-left-x, top-left-y, width, height)`. + detection_scores(numpy.ndarray or list): List of detection scores (probability) of each detected object. + class_ids (numpy.ndarray or list): List of class_ids (int) corresponding to labels of the detected object. Default is `None`. + + Returns: + list: List of tracks being currently tracked by the tracker. Each track is represented by the tuple with elements `(frame_id, track_id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z)`. 
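+
+        Example (illustrative usage, assuming a single detection in `(x, y, w, h)` format):
+            >>> tracker = Tracker(max_lost=2)
+            >>> tracks = tracker.update(
+            ...     bboxes=[(10, 20, 50, 80)], detection_scores=[0.9], class_ids=[1])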
+ """ + + self.frame_count += 1 + + if len(bboxes) == 0: + lost_ids = list(self.tracks.keys()) + + for track_id in lost_ids: + self.tracks[track_id].lost += 1 + if self.tracks[track_id].lost > self.max_lost: + self._remove_track(track_id) + + outputs = self._get_tracks(self.tracks) + return outputs + + detections = Tracker.preprocess_input(bboxes, class_ids, detection_scores) + + track_ids = list(self.tracks.keys()) + + updated_tracks, updated_detections = [], [] + + if len(track_ids): + track_centroids = np.array([self.tracks[tid].centroid for tid in track_ids]) + detection_centroids = get_centroid(np.asarray(bboxes)) + + centroid_distances = distance.cdist(track_centroids, detection_centroids) + + track_indices = np.amin(centroid_distances, axis=1).argsort() + + for idx in track_indices: + track_id = track_ids[idx] + + remaining_detections = [ + (i, d) for (i, d) in enumerate(centroid_distances[idx, :]) if i not in updated_detections] + + if len(remaining_detections): + detection_idx, detection_distance = min(remaining_detections, key=lambda x: x[1]) + bbox, class_id, confidence = detections[detection_idx] + self._update_track(track_id, self.frame_count, bbox, confidence, class_id=class_id) + updated_detections.append(detection_idx) + updated_tracks.append(track_id) + + if len(updated_tracks) == 0 or track_id is not updated_tracks[-1]: + self.tracks[track_id].lost += 1 + if self.tracks[track_id].lost > self.max_lost: + self._remove_track(track_id) + + for i, (bbox, class_id, confidence) in enumerate(detections): + if i not in updated_detections: + self._add_track(self.frame_count, bbox, confidence, class_id=class_id) + + outputs = self._get_tracks(self.tracks) + return outputs diff --git a/src/base_trackers/tracker_img.py b/src/base_trackers/tracker_img.py new file mode 100644 index 0000000..5b1300f --- /dev/null +++ b/src/base_trackers/tracker_img.py @@ -0,0 +1,63 @@ +import argparse +import time +import cv2 + + +ap = argparse.ArgumentParser() +ap.add_argument("-v", "--video", type=str, default='cars.mp4', help="path to input video file") +ap.add_argument("-t", "--tracker", type=str, default="kcf", help="OpenCV object tracker type") +args = vars(ap.parse_args()) + +OPENCV_OBJECT_TRACKERS = { + "csrt": cv2.TrackerCSRT_create, + "kcf": cv2.TrackerKCF_create, + "boosting": cv2.TrackerBoosting_create, + "mil": cv2.TrackerMIL_create, + "tld": cv2.TrackerTLD_create, + "medianflow": cv2.TrackerMedianFlow_create, + "mosse": cv2.TrackerMOSSE_create +} + +trackers = cv2.MultiTracker_create() + +vs = cv2.VideoCapture(args["video"]) + +while True: + ok, frame = vs.read() + if not ok: + break + + # resize the frame (so we can process it faster) + frame = cv2.resize(frame, (600, 400)) + + (success, boxes) = trackers.update(frame) + print(success) + + for box in boxes: + (x, y, w, h) = [int(v) for v in box] + cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) + + cv2.imshow("Frame", frame) + key = cv2.waitKey(1) & 0xFF + + # if the 's' key is selected, we are going to "select" a bounding + # box to tracks + if key == ord("s"): + # select the bounding box of the object we want to track (make + # sure you press ENTER or SPACE after selecting the ROI) + box = cv2.selectROI("Frame", frame, fromCenter=False, showCrosshair=True) + + # create a new object tracker for the bounding box and add it to our multi-object tracker + tracker = OPENCV_OBJECT_TRACKERS[args["tracker"]]() + trackers.add(tracker, frame, box) + + elif key == ord("q"): # if the `q` key was pressed, break from the loop + break + + 
time.sleep(0.1) + +# if we are using a webcam, release the pointer +vs.release() + +# close all windows +cv2.destroyAllWindows() diff --git a/src/base_trackers/utils.py b/src/base_trackers/utils.py new file mode 100644 index 0000000..bb3c9e7 --- /dev/null +++ b/src/base_trackers/utils.py @@ -0,0 +1,309 @@ +import numpy as np +import cv2 as cv + + +def get_centroid(bboxes): + """ + Calculate centroids for multiple bounding boxes. + + Args: + bboxes (numpy.ndarray): Array of shape `(n, 4)` or of shape `(4,)` where + each row contains `(xmin, ymin, width, height)`. + + Returns: + numpy.ndarray: Centroid (x, y) coordinates of shape `(n, 2)` or `(2,)`. + + """ + + one_bbox = False + if len(bboxes.shape) == 1: + one_bbox = True + bboxes = bboxes[None, :] + + xmin = bboxes[:, 0] + ymin = bboxes[:, 1] + w, h = bboxes[:, 2], bboxes[:, 3] + + xc = xmin + 0.5*w + yc = ymin + 0.5*h + + x = np.hstack([xc[:, None], yc[:, None]]) + + if one_bbox: + x = x.flatten() + return x + + +def iou(bbox1, bbox2): + """ + Calculates the intersection-over-union of two bounding boxes. + Source: https://github.com/bochinski/iou-tracker/blob/master/util.py + + Args: + bbox1 (numpy.array or list[floats]): Bounding box of length 4 containing + ``(x-top-left, y-top-left, x-bottom-right, y-bottom-right)``. + bbox2 (numpy.array or list[floats]): Bounding box of length 4 containing + ``(x-top-left, y-top-left, x-bottom-right, y-bottom-right)``. + + Returns: + float: intersection-over-onion of bbox1, bbox2. + """ + + bbox1 = [float(x) for x in bbox1] + bbox2 = [float(x) for x in bbox2] + + (x0_1, y0_1, x1_1, y1_1), (x0_2, y0_2, x1_2, y1_2) = bbox1, bbox2 + + # get the overlap rectangle + overlap_x0 = max(x0_1, x0_2) + overlap_y0 = max(y0_1, y0_2) + overlap_x1 = min(x1_1, x1_2) + overlap_y1 = min(y1_1, y1_2) + + # check if there is an overlap + if overlap_x1 - overlap_x0 <= 0 or overlap_y1 - overlap_y0 <= 0: + return 0.0 + + # if yes, calculate the ratio of the overlap to each ROI size and the unified size + size_1 = (x1_1 - x0_1) * (y1_1 - y0_1) + size_2 = (x1_2 - x0_2) * (y1_2 - y0_2) + size_intersection = (overlap_x1 - overlap_x0) * (overlap_y1 - overlap_y0) + size_union = size_1 + size_2 - size_intersection + + iou_ = size_intersection / size_union + + return iou_ + + +def iou_xywh(bbox1, bbox2): + """ + Calculates the intersection-over-union of two bounding boxes. + Source: https://github.com/bochinski/iou-tracker/blob/master/util.py + + Args: + bbox1 (numpy.array or list[floats]): bounding box of length 4 containing ``(x-top-left, y-top-left, width, height)``. + bbox2 (numpy.array or list[floats]): bounding box of length 4 containing ``(x-top-left, y-top-left, width, height)``. + + Returns: + float: intersection-over-onion of bbox1, bbox2. + """ + bbox1 = bbox1[0], bbox1[1], bbox1[0]+bbox1[2], bbox1[1]+bbox1[3] + bbox2 = bbox2[0], bbox2[1], bbox2[0]+bbox2[2], bbox2[1]+bbox2[3] + + iou_ = iou(bbox1, bbox2) + + return iou_ + + +def xyxy2xywh(xyxy): + """ + Convert bounding box coordinates from (xmin, ymin, xmax, ymax) format to (xmin, ymin, width, height). + + Args: + xyxy (numpy.ndarray): + + Returns: + numpy.ndarray: Bounding box coordinates (xmin, ymin, width, height). 
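+
+    Example (illustrative; width/height include the +1 pixel convention used below):
+        >>> xyxy2xywh(np.array([10, 20, 30, 40]))
+        array([10, 20, 21, 21])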
+ + """ + + if len(xyxy.shape) == 2: + w, h = xyxy[:, 2] - xyxy[:, 0] + 1, xyxy[:, 3] - xyxy[:, 1] + 1 + xywh = np.concatenate((xyxy[:, 0:2], w[:, None], h[:, None]), axis=1) + return xywh.astype("int") + elif len(xyxy.shape) == 1: + (left, top, right, bottom) = xyxy + width = right - left + 1 + height = bottom - top + 1 + return np.array([left, top, width, height]).astype('int') + else: + raise ValueError("Input shape not compatible.") + +def box_xywh_to_xyxy(xywh_box): + x0, y0, w, h = xywh_box[::, 0], xywh_box[::, 1], xywh_box[::, 2], xywh_box[::, 3] + b = [x0, y0, (x0 + w), (y0 + h)] + return np.column_stack(b) + +def xywh2xyxy(xywh): + """ + Convert bounding box coordinates from (xmin, ymin, width, height) to (xmin, ymin, xmax, ymax) format. + + Args: + xywh (numpy.ndarray): Bounding box coordinates as `(xmin, ymin, width, height)`. + + Returns: + numpy.ndarray : Bounding box coordinates as `(xmin, ymin, xmax, ymax)`. + + """ + + if len(xywh.shape) == 2: + x = xywh[:, 0] + xywh[:, 2] + y = xywh[:, 1] + xywh[:, 3] + xyxy = np.concatenate((xywh[:, 0:2], x[:, None], y[:, None]), axis=1).astype('int') + return xyxy + if len(xywh.shape) == 1: + x, y, w, h = xywh + xr = x + w + yb = y + h + return np.array([x, y, xr, yb]).astype('int') + + +def midwh2xywh(midwh): + """ + Convert bounding box coordinates from (xmid, ymid, width, height) to (xmin, ymin, width, height) format. + + Args: + midwh (numpy.ndarray): Bounding box coordinates (xmid, ymid, width, height). + + Returns: + numpy.ndarray: Bounding box coordinates (xmin, ymin, width, height). + """ + + if len(midwh.shape) == 2: + xymin = midwh[:, 0:2] - midwh[:, 2:] * 0.5 + wh = midwh[:, 2:] + xywh = np.concatenate([xymin, wh], axis=1).astype('int') + return xywh + if len(midwh.shape) == 1: + xmid, ymid, w, h = midwh + xywh = np.array([xmid-w*0.5, ymid-h*0.5, w, h]).astype('int') + return xywh + + +def intersection_complement_indices(big_set_indices, small_set_indices): + """ + Get the complement of intersection of two sets of indices. + + Args: + big_set_indices (numpy.ndarray): Indices of big set. + small_set_indices (numpy.ndarray): Indices of small set. + + Returns: + numpy.ndarray: Indices of set which is complementary to intersection of two input sets. + """ + assert big_set_indices.shape[0] >= small_set_indices.shape[1] + n = len(big_set_indices) + mask = np.ones((n,), dtype=bool) + mask[small_set_indices] = False + intersection_complement = big_set_indices[mask] + return intersection_complement + + +def nms(boxes, scores, overlapThresh, classes=None): + """ + Non-maximum suppression. based on Malisiewicz et al. + + Args: + boxes (numpy.ndarray): Boxes to process (xmin, ymin, xmax, ymax) + scores (numpy.ndarray): Corresponding scores for each box + overlapThresh (float): Overlap threshold for boxes to merge + classes (numpy.ndarray, optional): Class ids for each box. 
+ + Returns: + tuple: a tuple containing: + - boxes (list): nms boxes + - scores (list): nms scores + - classes (list, optional): nms classes if specified + + """ + + if boxes.dtype.kind == "i": + boxes = boxes.astype("float") + + if scores.dtype.kind == "i": + scores = scores.astype("float") + + pick = [] + + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = (x2 - x1 + 1) * (y2 - y1 + 1) + + idxs = np.argsort(scores) + + while len(idxs) > 0: + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + + overlap = (w * h) / area[idxs[:last]] + + # delete all indexes from the index list that have + idxs = np.delete(idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))) + + if classes is not None: + return boxes[pick], scores[pick], classes[pick] + else: + return boxes[pick], scores[pick] + + +def draw_tracks(image, tracks): + """ + Draw on input image. + + Args: + image (numpy.ndarray): image + tracks (list): list of tracks to be drawn on the image. + + Returns: + numpy.ndarray: image with the track-ids drawn on it. + """ + + for trk in tracks: + + trk_id = trk[1] + xmin = trk[2] + ymin = trk[3] + width = trk[4] + height = trk[5] + + xcentroid, ycentroid = int(xmin + 0.5*width), int(ymin + 0.5*height) + + text = "ID {}".format(trk_id) + + cv.putText(image, text, (xcentroid - 10, ycentroid - 10), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) + cv.circle(image, (xcentroid, ycentroid), 4, (0, 255, 0), -1) + + return image + + +def load_labelsjson(json_file): + import json + with open(json_file) as file: + data = json.load(file) + labels = {int(k): v for k, v in data.items()} + return labels + + +def dict2jsonfile(dict_data, json_file_path): + import json + with open(json_file_path, 'w') as outfile: + json.dump(dict_data, outfile) + + +if __name__ == '__main__': + bb = np.random.random_integers(0, 100, size=(20,)).reshape((5, 4)) + c = get_centroid(bb) + print(bb, c) + + bb2 = np.array([1, 2, 3, 4]) + c2 = get_centroid(bb2) + print(bb2, c2) + + data = { + 0: 'background', 1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle', 6: 'bus', + 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', + 15: 'person', 16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor' + } + dict2jsonfile(data, '../../examples/pretrained_models/caffemodel_weights/ssd_mobilenet_caffe_names.json') + diff --git a/src/combine_frames.py b/src/combine_frames.py new file mode 100644 index 0000000..327e212 --- /dev/null +++ b/src/combine_frames.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Combine two sets of frames to one. 
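+
+For every sequence present in both FRAME_DIR_1 and FRAME_DIR_2, the two renderings
+of each frame are stacked vertically and written to OUTPUT_DIR.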
+""" +import os +import os.path as osp + +from PIL import Image + +OUTPUT_DIR = 'models/mot17_masks_track_rcnn_and_v3_combined' + +FRAME_DIR_1 = 'models/mot17_masks_track_rcnn/MOTS20-TEST' +FRAME_DIR_2 = 'models/mot17_masks_v3/MOTS20-ALL' + + +if __name__ == '__main__': + seqs_1 = os.listdir(FRAME_DIR_1) + seqs_2 = os.listdir(FRAME_DIR_2) + + if not osp.exists(OUTPUT_DIR): + os.makedirs(OUTPUT_DIR) + + for seq in seqs_1: + if seq in seqs_2: + print(seq) + seg_output_dir = osp.join(OUTPUT_DIR, seq) + if not osp.exists(seg_output_dir): + os.makedirs(seg_output_dir) + + frames = os.listdir(osp.join(FRAME_DIR_1, seq)) + + for frame in frames: + img_1 = Image.open(osp.join(FRAME_DIR_1, seq, frame)) + img_2 = Image.open(osp.join(FRAME_DIR_2, seq, frame)) + + width = img_1.size[0] + height = img_2.size[1] + + combined_frame = Image.new('RGB', (width, height * 2)) + combined_frame.paste(img_1, (0, 0)) + combined_frame.paste(img_2, (0, height)) + + combined_frame.save(osp.join(seg_output_dir, f'{frame}')) diff --git a/src/compute_best_mean_epoch_from_splits.py b/src/compute_best_mean_epoch_from_splits.py new file mode 100644 index 0000000..67c749b --- /dev/null +++ b/src/compute_best_mean_epoch_from_splits.py @@ -0,0 +1,232 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import os +import json +import numpy as np + + +LOG_DIR = 'logs/visdom' + +METRICS = ['MOTA', 'IDF1', 'BBOX AP IoU=0.50:0.95', 'MASK AP IoU=0.50:0.95'] + +RUNS = [ + 'mot17_train_1_deformable_full_res', + 'mot17_train_2_deformable_full_res', + 'mot17_train_3_deformable_full_res', + 'mot17_train_4_deformable_full_res', + 'mot17_train_5_deformable_full_res', + 'mot17_train_6_deformable_full_res', + 'mot17_train_7_deformable_full_res', + ] + +RUNS = [ + 'mot17_train_1_no_pretrain_deformable_tracking', + 'mot17_train_2_no_pretrain_deformable_tracking', + 'mot17_train_3_no_pretrain_deformable_tracking', + 'mot17_train_4_no_pretrain_deformable_tracking', + 'mot17_train_5_no_pretrain_deformable_tracking', + 'mot17_train_6_no_pretrain_deformable_tracking', + 'mot17_train_7_no_pretrain_deformable_tracking', + ] + +RUNS = [ + 'mot17_train_1_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_2_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_3_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_4_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_5_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_6_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_7_coco_pretrain_deformable_tracking_lr=0.00001', + ] + +RUNS = [ + 'mot17_train_1_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_2_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_3_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_4_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_5_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_6_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + 'mot17_train_7_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', + ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_eos_coef=0.2', +# 'mot17_train_2_no_pretrain_deformable_tracking_eos_coef=0.2', +# 'mot17_train_3_no_pretrain_deformable_tracking_eos_coef=0.2', +# 'mot17_train_4_no_pretrain_deformable_tracking_eos_coef=0.2', +# 'mot17_train_5_no_pretrain_deformable_tracking_eos_coef=0.2', +# 'mot17_train_6_no_pretrain_deformable_tracking_eos_coef=0.2', +# 
'mot17_train_7_no_pretrain_deformable_tracking_eos_coef=0.2', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_2_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_3_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_4_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_5_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_6_no_pretrain_deformable_tracking_lr_drop=50', +# 'mot17_train_7_no_pretrain_deformable_tracking_lr_drop=50', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_2_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_3_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_4_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_5_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_6_no_pretrain_deformable_tracking_save_model_interval=1', +# 'mot17_train_7_no_pretrain_deformable_tracking_save_model_interval=1', +# ] + +# RUNS = [ + # 'mot17_train_1_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_2_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_3_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_4_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_5_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_6_no_pretrain_deformable_tracking_save_model_interval=1', + # 'mot17_train_7_no_pretrain_deformable_tracking_save_model_interval=1', + # ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_full_res', +# 'mot17_train_2_no_pretrain_deformable_full_res', +# 'mot17_train_3_no_pretrain_deformable_full_res', +# 'mot17_train_4_no_pretrain_deformable_full_res', +# 'mot17_train_5_no_pretrain_deformable_full_res', +# 'mot17_train_6_no_pretrain_deformable_full_res', +# 'mot17_train_7_no_pretrain_deformable_full_res', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 
'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', +# ] + +# RUNS = [ +# 'mot17_train_1_no_pretrain_deformable', +# 'mot17_train_2_no_pretrain_deformable', +# 'mot17_train_3_no_pretrain_deformable', +# 'mot17_train_4_no_pretrain_deformable', +# 'mot17_train_5_no_pretrain_deformable', +# 'mot17_train_6_no_pretrain_deformable', +# 'mot17_train_7_no_pretrain_deformable', +# ] + +# +# MOTS 4-fold split +# + +# RUNS = [ +# 'mots20_train_1_coco_tracking', +# 'mots20_train_2_coco_tracking', +# 'mots20_train_3_coco_tracking', +# 'mots20_train_4_coco_tracking', +# ] + +# RUNS = [ +# 'mots20_train_1_coco_tracking_full_res_masks=False', +# 'mots20_train_2_coco_tracking_full_res_masks=False', +# 'mots20_train_3_coco_tracking_full_res_masks=False', +# 'mots20_train_4_coco_tracking_full_res_masks=False', +# ] + +# RUNS = [ +# 'mots20_train_1_coco_full_res_pretrain_masks=False_lr_0_0001', +# 'mots20_train_2_coco_full_res_pretrain_masks=False_lr_0_0001', +# 'mots20_train_3_coco_full_res_pretrain_masks=False_lr_0_0001', +# 'mots20_train_4_coco_full_res_pretrain_masks=False_lr_0_0001', +# ] + +# RUNS = [ +# 'mots20_train_1_coco_tracking_full_res_masks=False_pretrain', +# 'mots20_train_2_coco_tracking_full_res_masks=False_pretrain', +# 'mots20_train_3_coco_tracking_full_res_masks=False_pretrain', +# 'mots20_train_4_coco_tracking_full_res_masks=False_pretrain', +# ] + +# RUNS = [ +# 'mot17det_train_1_mots_track_bbox_proposals_pretrain_train_1_mots_vis_save_model_interval_1', +# 'mot17det_train_2_mots_track_bbox_proposals_pretrain_train_3_mots_vis_save_model_interval_1', +# 'mot17det_train_3_mots_track_bbox_proposals_pretrain_train_4_mots_vis_save_model_interval_1', +# 'mot17det_train_4_mots_track_bbox_proposals_pretrain_train_6_mots_vis_save_model_interval_1', +# ] + +if __name__ == '__main__': + results = {} + + for r in RUNS: + print(r) + log_file = 
os.path.join(LOG_DIR, f"{r}.json") + + with open(log_file) as json_file: + data = json.load(json_file) + + window = [ + window for window in data['jsons'].values() + if window['title'] == 'VAL EVAL EPOCHS'][0] + + for m in METRICS: + if m not in window['legend']: + continue + elif m not in results: + results[m] = [] + + idxs = window['legend'].index(m) + + values = window['content']['data'][idxs]['y'] + results[m].append(values) + + print(f'NUM EPOCHS: {len(values)}') + + min_length = min([len(l) for l in next(iter(results.values()))]) + + for metric in results.keys(): + results[metric] = [l[:min_length] for l in results[metric]] + + mean_results = { + metric: np.array(results[metric]).mean(axis=0) + for metric in results.keys()} + + print("* METRIC INTERVAL = BEST EPOCHS") + for metric in results.keys(): + best_interval = mean_results[metric].argmax() + print(mean_results[metric]) + print( + f'{metric}: {mean_results[metric].max():.2%} at {best_interval + 1}/{len(mean_results[metric])} ' + f'{[(mmetric, f"{mean_results[mmetric][best_interval]:.2%}") for mmetric in results.keys() if not mmetric == metric]}') diff --git a/src/generate_coco_from_actiongenome.py b/src/generate_coco_from_actiongenome.py new file mode 100644 index 0000000..894356d --- /dev/null +++ b/src/generate_coco_from_actiongenome.py @@ -0,0 +1,214 @@ +import json +import os +import pickle +import torch +import numpy as np + +DATA_ROOT = 'data/ActionGenome' + +def save_annotations(split_name, filter_nonperson_box_frame=True, video_limit=None): + assert split_name in ['train', 'test'] + + with open(DATA_ROOT + '/annotations/person_bbox.pkl', 'rb') as f: + person_bbox = pickle.load(f) + with open(DATA_ROOT + '/annotations/object_bbox_and_relationship_filtersmall.pkl', 'rb') as f: # follow STTran + # with open(DATA_ROOT + '/annotations/object_bbox_and_relationship.pkl', 'rb') as f: + object_bbox = pickle.load(f) + + # # check image exist + # if split_name == 'train': + # non_exist_frames = [] + # for frame_key in person_bbox.keys(): + # if not os.path.isfile(f"{DATA_ROOT}/frames/{frame_key}"): + # print(f'{frame_key}') + # non_exist_frames.append(frame_key) + # print(non_exist_frames) + # assert len(non_exist_frames) == 0 + + # collect valid frames + video_dict = {} + for frame_key in person_bbox.keys(): + if object_bbox[frame_key][0]['metadata']['set'] == split_name: + frame_valid = False + for j in object_bbox[frame_key]: # the frame is valid if there is visible bbox + if j['visible']: frame_valid = True + + if frame_valid: + video_name, frame_num = frame_key.split('/') + if video_name in video_dict.keys(): + video_dict[video_name].append(frame_key) + else: + video_dict[video_name] = [frame_key] + + # get annotations + video_level_annotations, video_list, video_size = {}, [], [] + non_gt_human_nums, valid_nums = 0, 0 + one_frame_video, one_frame_video, non_person_video = 0, 0, 0 + annotation_id, image_id = 0, 0 + for i in video_dict.keys(): + video = [] + gt_annotation_video = {} + for j in sorted(video_dict[i]): + if filter_nonperson_box_frame: + if person_bbox[j]['bbox'].shape[0] == 0: + non_gt_human_nums += 1 + continue + else: + video.append(j) + valid_nums += 1 + + # person box # 查看数据集,test split一帧最多1个human + frame_person_box = [float(x) for x in person_bbox[j]['bbox'][0]] + gt_annotation_frame = [ + { + "id": annotation_id, + "bbox": [frame_person_box[0], frame_person_box[1], max(frame_person_box[2]-frame_person_box[0], 0), max(frame_person_box[3]-frame_person_box[1], 0)], + "image_id": image_id, + "segmentation": [], 
+ "ignore": False, + "visibility": True, + "area": (frame_person_box[2]-frame_person_box[0]) * (frame_person_box[3]-frame_person_box[1]), + "iscrowd": 0, + "seq": i, + "category_id": object_classes.index('person'), + "track_id": 0, + } + ] + annotation_id += 1 + + # non-human objects + for k in object_bbox[j]: + if k['visible']: + gt_annotation_frame.append({ + "id": annotation_id, + "bbox": k['bbox'], + "image_id": image_id, + "segmentation": [], + "ignore": False, + "visibility": True, + "area": k['bbox'][2] * k['bbox'][3], + "iscrowd": 0, + "seq": i, + "category_id": object_classes.index(k['class']), + "track_id": object_classes.index(k['class']), # label as track_id, since only one intance of the same class exists in ActionGenome dataset + # "attention_relationship": [attention_relationships.index(r) for r in k['attention_relationship']], + # "spatial_relationship": [spatial_relationships.index(r) for r in k['spatial_relationship']], + # "contacting_relationship": [contacting_relationships.index(r) for r in k['contacting_relationship']], + "relationships": [relationship_classes.index(r) for r in k['attention_relationship']+k['spatial_relationship']+k['contacting_relationship']] + }) + annotation_id += 1 + + image_id += 1 + gt_annotation_video[j] = gt_annotation_frame + + if len(video) > 2: + video_list.append(video) + video_size.append(person_bbox[j]['bbox_size']) + video_level_annotations[i]= gt_annotation_video + elif len(video) == 1: + one_frame_video += 1 + else: + non_person_video += 1 + + print('x'*60) + print('There are {} videos and {} valid frames'.format(len(video_list), valid_nums)) + print('\t{} videos are invalid (no person), remove them'.format(non_person_video)) + print('\t{} videos are invalid (only one frame), remove them'.format(one_frame_video)) + print('\t{} frames have no human bbox in GT, remove them!'.format(non_gt_human_nums)) + print('x' * 60) + + # to COCO format + seqs = sorted(list(video_level_annotations.keys())) + if video_limit is not None: seqs=seqs[:video_limit] + annotations_coco_format = { + 'type': 'instances', + 'categories': [{'id': id, 'name': c, 'supercategory': c} for id, c in enumerate(object_classes)], + 'images': [], + 'annotations': [], + 'sequences': seqs, + 'sequence_startend_image_ids': [] + } + for vk in seqs: + video_info = video_level_annotations[vk] + + vframe_image_ids = [video_info[fkey][0]['image_id'] for fkey in sorted(video_info.keys())] + annotations_coco_format['sequence_startend_image_ids'].append((min(vframe_image_ids), max(vframe_image_ids))) + + # https://zhuanlan.zhihu.com/p/29393415 + for fid, fkey in enumerate(sorted(video_info.keys())): + annotations_coco_format['images'].append({ + 'id': video_info[fkey][0]['image_id'], + 'file_name': fkey, + 'frame_id': fid, + 'first_frame_image_id': min(vframe_image_ids) + }) + annotations_coco_format['annotations'].extend(video_info[fkey]) + + annotation_file = f'{DATA_ROOT}/{split_name}_cocofmt.json' + if video_limit is not None: annotation_file = f'{DATA_ROOT}/{split_name}_v{video_limit}_cocofmt.json' + with open(annotation_file, 'w') as anno_file: + json.dump(annotations_coco_format, anno_file, indent=4) + print(f'Saved {split_name} annotaions to {annotation_file}') + +if __name__ == '__main__': + ## save meta infos + save_path=f'{DATA_ROOT}/meta_infos.json' + + # collect the object classes + object_classes = [] + object_classes.append('__background__') + with open(os.path.join(DATA_ROOT, 'annotations/object_classes.txt'), 'r') as f: + for line in f.readlines(): + line = 
line.strip('\n') + object_classes.append(line) + object_classes[9] = 'closet/cabinet' + object_classes[11] = 'cup/glass/bottle' + object_classes[23] = 'paper/notebook' + object_classes[24] = 'phone/camera' + object_classes[31] = 'sofa/couch' + + # collect relationship classes + relationship_classes = [] + with open(os.path.join(DATA_ROOT, 'annotations/relationship_classes.txt'), 'r') as f: + for line in f.readlines(): + line = line.strip('\n') + relationship_classes.append(line) + relationship_classes[0] = 'looking_at' + relationship_classes[1] = 'not_looking_at' + relationship_classes[5] = 'in_front_of' + relationship_classes[7] = 'on_the_side_of' + relationship_classes[10] = 'covered_by' + relationship_classes[11] = 'drinking_from' + relationship_classes[13] = 'have_it_on_the_back' + relationship_classes[15] = 'leaning_on' + relationship_classes[16] = 'lying_on' + relationship_classes[17] = 'not_contacting' + relationship_classes[18] = 'other_relationship' + relationship_classes[19] = 'sitting_on' + relationship_classes[20] = 'standing_on' + relationship_classes[25] = 'writing_on' + + attention_relationships = relationship_classes[0:3] + spatial_relationships = relationship_classes[3:9] + contacting_relationships = relationship_classes[9:] + + # save to json + meta_infos = { + 'object_classes': object_classes, + 'relationship_classes': relationship_classes, + 'attention_relationships': attention_relationships, + 'spatial_relationships': spatial_relationships, + 'contacting_relationships': contacting_relationships + } + + with open(save_path, 'w') as f: + json.dump(meta_infos, f, indent=4) + print(f'Saved meta infos to {save_path}') + + # save annotation in COCO format + save_annotations('train') + save_annotations('train', video_limit=1000) + + save_annotations('test') + save_annotations('test', video_limit=500) + save_annotations('test', video_limit=200) diff --git a/src/generate_coco_from_crowdhuman.py b/src/generate_coco_from_crowdhuman.py new file mode 100644 index 0000000..b60bb75 --- /dev/null +++ b/src/generate_coco_from_crowdhuman.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Generates COCO data and annotation structure from CrowdHuman data. +""" +import json +import os + +from generate_coco_from_mot import check_coco_from_mot + +DATA_ROOT = 'data/CrowdHuman' +VIS_THRESHOLD = 0.0 + + +def generate_coco_from_crowdhuman(): + """ + Generate COCO data from CrowdHuman. 
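+
+    Converts the `annotation_train.odgt` and `annotation_val.odgt` files into a
+    single COCO-style `annotations/train_val.json`, keeping the full-body box
+    (`fbox`) of every `person` entry and dropping images with fewer than 2 or
+    more than 50 annotated boxes.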
+ """ + split_name = 'train_val' + split = 'train_val' + + annotations = {} + annotations['type'] = 'instances' + annotations['images'] = [] + annotations['categories'] = [{"supercategory": "pedestrian", + "name": "pedestrian", + "id": 1}] + annotations['annotations'] = [] + annotation_file = os.path.join(DATA_ROOT, f'annotations/{split_name}.json') + + # IMAGES + imgs_list_dir = os.listdir(os.path.join(DATA_ROOT, split)) + for i, img in enumerate(sorted(imgs_list_dir)): + annotations['images'].append({"file_name": img, "id": i, }) + + # GT + annotation_id = 0 + img_file_name_to_id = { + os.path.splitext(img_dict['file_name'])[0]: img_dict['id'] + for img_dict in annotations['images']} + + for split in ['train', 'val']: + odgt_annos_file = os.path.join(DATA_ROOT, f'annotations/annotation_{split}.odgt') + with open(odgt_annos_file, 'r+') as anno_file: + datalist = anno_file.readlines() + + ignores = 0 + for data in datalist: + json_data = json.loads(data) + gtboxes = json_data['gtboxes'] + for gtbox in gtboxes: + if gtbox['tag'] == 'person': + bbox = gtbox['fbox'] + area = bbox[2] * bbox[3] + + ignore = False + visibility = 1.0 + # if 'occ' in gtbox['extra']: + # visibility = 1.0 - gtbox['extra']['occ'] + # if visibility <= VIS_THRESHOLD: + # ignore = True + + if 'ignore' in gtbox['extra']: + ignore = ignore or bool(gtbox['extra']['ignore']) + + ignores += int(ignore) + + annotation = { + "id": annotation_id, + "bbox": bbox, + "image_id": img_file_name_to_id[json_data['ID']], + "segmentation": [], + "ignore": int(ignore), + "visibility": visibility, + "area": area, + "iscrowd": 0, + "category_id": annotations['categories'][0]['id'],} + + annotation_id += 1 + annotations['annotations'].append(annotation) + + # max objs per image + num_objs_per_image = {} + for anno in annotations['annotations']: + image_id = anno["image_id"] + if image_id in num_objs_per_image: + num_objs_per_image[image_id] += 1 + else: + num_objs_per_image[image_id] = 1 + + print(f'max objs per image: {max([n for n in num_objs_per_image.values()])}') + print(f'ignore augs: {ignores}/{len(annotations["annotations"])}') + print(len(annotations['images'])) + + ignore_img_ids = [] + for img_id, num_objs in num_objs_per_image.items(): + if num_objs > 50 or num_objs < 2: + ignore_img_ids.append(img_id) + + annotations['images'] = [ + img for img in annotations['images'] + if img['id'] not in ignore_img_ids] + + annotations['annotations'] = [ + anno for anno in annotations['annotations'] + if anno['image_id'] not in ignore_img_ids] + + print(ignore_img_ids) + print(len(annotations['images'])) + + with open(annotation_file, 'w') as anno_file: + json.dump(annotations, anno_file, indent=4) + + +if __name__ == '__main__': + generate_coco_from_crowdhuman() + + # check_coco_from_mot('train') diff --git a/src/generate_coco_from_mot.py b/src/generate_coco_from_mot.py new file mode 100644 index 0000000..f43a1db --- /dev/null +++ b/src/generate_coco_from_mot.py @@ -0,0 +1,388 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Generates COCO data and annotation structure from MOTChallenge data. 
+""" +import argparse +import configparser +import csv +import json +import os +import shutil + +import numpy as np +import pycocotools.mask as rletools +import skimage.io as io +import torch +from matplotlib import pyplot as plt +from pycocotools.coco import COCO +from scipy.optimize import linear_sum_assignment +from torchvision.ops.boxes import box_iou + +from trackformer.datasets.tracking.mots20_sequence import load_mots_gt + +DATA_ROOT = 'data/MOT17' +MOTS_ROOT = 'data/MOTS20' +VIS_THRESHOLD = 0.0 + +MOT_15_SEQS_INFO = { + 'ETH-Bahnhof': {'img_width': 640, 'img_height': 480, 'seq_length': 1000}, + 'ETH-Sunnyday': {'img_width': 640, 'img_height': 480, 'seq_length': 354}, + 'KITTI-13': {'img_width': 1242, 'img_height': 375, 'seq_length': 340}, + 'KITTI-17': {'img_width': 1224, 'img_height': 370, 'seq_length': 145}, + 'PETS09-S2L1': {'img_width': 768, 'img_height': 576, 'seq_length': 795}, + 'TUD-Campus': {'img_width': 640, 'img_height': 480, 'seq_length': 71}, + 'TUD-Stadtmitte': {'img_width': 640, 'img_height': 480, 'seq_length': 179},} + + +def generate_coco_from_mot(split_name='train', seqs_names=None, + root_split='train', mots=False, mots_vis=False, + frame_range=None): + """ + Generates COCO data from MOT. + """ + global DATA_ROOT + + if frame_range is None: + frame_range = {'start': 0.0, 'end': 1.0} + + if mots: + DATA_ROOT = MOTS_ROOT + root_split_path = os.path.join(DATA_ROOT, root_split) + root_split_mots_path = os.path.join(MOTS_ROOT, root_split) + coco_dir = os.path.join(DATA_ROOT, split_name) + + if os.path.isdir(coco_dir): + shutil.rmtree(coco_dir) + + os.mkdir(coco_dir) + + annotations = {} + annotations['type'] = 'instances' + annotations['images'] = [] + annotations['categories'] = [{"supercategory": "pedestrian", + "name": "pedestrian", + "id": 1}] + annotations['annotations'] = [] + + annotations_dir = os.path.join(os.path.join(DATA_ROOT, 'annotations')) + if not os.path.isdir(annotations_dir): + os.mkdir(annotations_dir) + annotation_file = os.path.join(annotations_dir, f'{split_name}.json') + + # IMAGE FILES + img_id = 0 + + seqs = sorted(os.listdir(root_split_path)) + + if seqs_names is not None: + seqs = [s for s in seqs if s in seqs_names] + annotations['sequences'] = seqs + annotations['frame_range'] = frame_range + print(split_name, seqs) + + for seq in seqs: + # CONFIG FILE + config = configparser.ConfigParser() + config_file = os.path.join(root_split_path, seq, 'seqinfo.ini') + + if os.path.isfile(config_file): + config.read(config_file) + img_width = int(config['Sequence']['imWidth']) + img_height = int(config['Sequence']['imHeight']) + seq_length = int(config['Sequence']['seqLength']) + else: + img_width = MOT_15_SEQS_INFO[seq]['img_width'] + img_height = MOT_15_SEQS_INFO[seq]['img_height'] + seq_length = MOT_15_SEQS_INFO[seq]['seq_length'] + + seg_list_dir = sorted(os.listdir(os.path.join(root_split_path, seq, 'img1'))) ## !! 
fix + start_frame = int(frame_range['start'] * seq_length) + end_frame = int(frame_range['end'] * seq_length) + seg_list_dir = seg_list_dir[start_frame: end_frame] + + print(f"{seq}: {len(seg_list_dir)}/{seq_length}") + seq_length = len(seg_list_dir) + + for i, img in enumerate(sorted(seg_list_dir)): + + if i == 0: + first_frame_image_id = img_id + + annotations['images'].append({"file_name": f"{seq}_{img}", + "height": img_height, + "width": img_width, + "id": img_id, + "frame_id": i, + "seq_length": seq_length, + "first_frame_image_id": first_frame_image_id}) + + img_id += 1 + + os.symlink(os.path.join(os.getcwd(), root_split_path, seq, 'img1', img), + os.path.join(coco_dir, f"{seq}_{img}")) + + # GT + annotation_id = 0 + img_file_name_to_id = { + img_dict['file_name']: img_dict['id'] + for img_dict in annotations['images']} + for seq in seqs: + # GT FILE + gt_file_path = os.path.join(root_split_path, seq, 'gt', 'gt.txt') + if mots: + gt_file_path = os.path.join( + root_split_mots_path, + seq.replace('MOT17', 'MOTS20'), + 'gt', + 'gt.txt') + if not os.path.isfile(gt_file_path): + continue + + seq_annotations = [] + if mots: + mask_objects_per_frame = load_mots_gt(gt_file_path) + for frame_id, mask_objects in mask_objects_per_frame.items(): + for mask_object in mask_objects: + # class_id = 1 is car + # class_id = 2 is pedestrian + # class_id = 10 IGNORE + if mask_object.class_id == 1: + continue + + bbox = rletools.toBbox(mask_object.mask) + bbox = [int(c) for c in bbox] + area = bbox[2] * bbox[3] + image_id = img_file_name_to_id.get(f"{seq}_{frame_id:06d}.jpg", None) + if image_id is None: + continue + + segmentation = { + 'size': mask_object.mask['size'], + 'counts': mask_object.mask['counts'].decode(encoding='UTF-8')} + + annotation = { + "id": annotation_id, + "bbox": bbox, + "image_id": image_id, + "segmentation": segmentation, + "ignore": mask_object.class_id == 10, + "visibility": 1.0, + "area": area, + "iscrowd": 0, + "seq": seq, + "category_id": annotations['categories'][0]['id'], + "track_id": mask_object.track_id} + + seq_annotations.append(annotation) + annotation_id += 1 + + annotations['annotations'].extend(seq_annotations) + else: + + seq_annotations_per_frame = {} + with open(gt_file_path, "r") as gt_file: + reader = csv.reader(gt_file, delimiter=' ' if mots else ',') + + for row in reader: + if int(row[6]) == 1 and (seq in MOT_15_SEQS_INFO or int(row[7]) == 1): + bbox = [float(row[2]), float(row[3]), float(row[4]), float(row[5])] + bbox = [int(c) for c in bbox] + + area = bbox[2] * bbox[3] + visibility = float(row[8]) + frame_id = int(row[0]) + image_id = img_file_name_to_id.get(f"{seq}_{frame_id:06d}.jpg", None) + if image_id is None: + continue + track_id = int(row[1]) + + + annotation = { + "id": annotation_id, + "bbox": bbox, + "image_id": image_id, + "segmentation": [], + "ignore": 0 if visibility > VIS_THRESHOLD else 1, + "visibility": visibility, + "area": area, + "iscrowd": 0, + "seq": seq, + "category_id": annotations['categories'][0]['id'], + "track_id": track_id} + + seq_annotations.append(annotation) + if frame_id not in seq_annotations_per_frame: + seq_annotations_per_frame[frame_id] = [] + seq_annotations_per_frame[frame_id].append(annotation) + + annotation_id += 1 + + annotations['annotations'].extend(seq_annotations) + + #change ignore based on MOTS mask + if mots_vis: + gt_file_mots = os.path.join( + root_split_mots_path, + seq.replace('MOT17', 'MOTS20'), + 'gt', + 'gt.txt') + if os.path.isfile(gt_file_mots): + mask_objects_per_frame = 
load_mots_gt(gt_file_mots) + + for frame_id, frame_annotations in seq_annotations_per_frame.items(): + mask_objects = mask_objects_per_frame[frame_id] + mask_object_bboxes = [rletools.toBbox(obj.mask) for obj in mask_objects] + mask_object_bboxes = torch.tensor(mask_object_bboxes).float() + + frame_boxes = [a['bbox'] for a in frame_annotations] + frame_boxes = torch.tensor(frame_boxes).float() + + # x,y,w,h --> x,y,x,y + frame_boxes[:, 2:] += frame_boxes[:, :2] + mask_object_bboxes[:, 2:] += mask_object_bboxes[:, :2] + + mask_iou = box_iou(mask_object_bboxes, frame_boxes) + + mask_indices, frame_indices = linear_sum_assignment(-mask_iou) + for m_i, f_i in zip(mask_indices, frame_indices): + if mask_iou[m_i, f_i] < 0.5: + continue + + if not frame_annotations[f_i]['visibility']: + frame_annotations[f_i]['ignore'] = 0 + + # max objs per image + num_objs_per_image = {} + for anno in annotations['annotations']: + image_id = anno["image_id"] + + if image_id in num_objs_per_image: + num_objs_per_image[image_id] += 1 + else: + num_objs_per_image[image_id] = 1 + + print(f'max objs per image: {max(list(num_objs_per_image.values()))}') + + with open(annotation_file, 'w') as anno_file: + json.dump(annotations, anno_file, indent=4) + + +def check_coco_from_mot(split='train'): + """ + Visualize generated COCO data. Only used for debugging. + """ + coco_dir = os.path.join(DATA_ROOT, split) + annotation_file = os.path.join(coco_dir, 'annotations.json') + + coco = COCO(annotation_file) + cat_ids = coco.getCatIds(catNms=['pedestrian']) + img_ids = coco.getImgIds(catIds=cat_ids) + + index = np.random.randint(0, len(img_ids)) + img = coco.loadImgs(img_ids[index])[0] + + i = io.imread(os.path.join(coco_dir, img['file_name'])) + + plt.imshow(i) + plt.axis('off') + ann_ids = coco.getAnnIds(imgIds=img['id'], catIds=cat_ids, iscrowd=None) + anns = coco.loadAnns(ann_ids) + coco.showAnns(anns, draw_bbox=True) + plt.savefig('annotations.png') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate COCO from MOT.') + parser.add_argument('--mots20', action='store_true') + args = parser.parse_args() + + mot15_seqs_names = list(MOT_15_SEQS_INFO.keys()) + + if args.mots20: + # + # MOTS20 + # + + # TRAIN SET + generate_coco_from_mot( + 'mots20_train_coco', + seqs_names=['MOTS20-02', 'MOTS20-05', 'MOTS20-09', 'MOTS20-11'], + mots=True) + + # TRAIN SPLITS + for i in range(4): + train_seqs = ['MOTS20-02', 'MOTS20-05', 'MOTS20-09', 'MOTS20-11'] + val_seqs = train_seqs.pop(i) + + generate_coco_from_mot( + f'mots20_train_{i + 1}_coco', + seqs_names=train_seqs, mots=True) + generate_coco_from_mot( + f'mots20_val_{i + 1}_coco', + seqs_names=val_seqs, mots=True) + else: + # + # MOT17 + # + + # CROSS VAL SPLIT 1 + generate_coco_from_mot( + 'mot17_train_cross_val_1_coco', + seqs_names=['MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-11-FRCNN']) + generate_coco_from_mot( + 'mot17_val_cross_val_1_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-10-FRCNN', 'MOT17-13-FRCNN']) + + # CROSS VAL SPLIT 2 + generate_coco_from_mot( + 'mot17_train_cross_val_2_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-13-FRCNN']) + generate_coco_from_mot( + 'mot17_val_cross_val_2_coco', + seqs_names=['MOT17-04-FRCNN', 'MOT17-11-FRCNN']) + + # CROSS VAL SPLIT 3 + generate_coco_from_mot( + 'mot17_train_cross_val_3_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN']) + generate_coco_from_mot( + 
'mot17_val_cross_val_3_coco', + seqs_names=['MOT17-05-FRCNN', 'MOT17-09-FRCNN']) + + # CROSS VAL FRAME SPLIT + generate_coco_from_mot( + 'mot17_train_cross_val_frame_0_0_to_0_25_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN'], + frame_range={'start': 0, 'end': 0.25}) + generate_coco_from_mot( + 'mot17_train_cross_val_frame_0_0_to_0_5_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN'], + frame_range={'start': 0, 'end': 0.5}) + generate_coco_from_mot( + 'mot17_train_cross_val_frame_0_5_to_1_0_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN'], + frame_range={'start': 0.5, 'end': 1.0}) + + generate_coco_from_mot( + 'mot17_train_cross_val_frame_0_75_to_1_0_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN'], + frame_range={'start': 0.75, 'end': 1.0}) + + # TRAIN SET + generate_coco_from_mot( + 'mot17_train_coco', + seqs_names=['MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', + 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN']) + + for i in range(0, 7): + train_seqs = [ + 'MOT17-02-FRCNN', 'MOT17-04-FRCNN', 'MOT17-05-FRCNN', 'MOT17-09-FRCNN', + 'MOT17-10-FRCNN', 'MOT17-11-FRCNN', 'MOT17-13-FRCNN'] + val_seqs = train_seqs.pop(i) + + generate_coco_from_mot( + f'mot17_train_{i + 1}_coco', + seqs_names=train_seqs) + generate_coco_from_mot( + f'mot17_val_{i + 1}_coco', + seqs_names=val_seqs) diff --git a/src/generate_coco_from_vidhoi.py b/src/generate_coco_from_vidhoi.py new file mode 100644 index 0000000..6afb307 --- /dev/null +++ b/src/generate_coco_from_vidhoi.py @@ -0,0 +1,259 @@ +import json +import os +import pickle +import torch +import numpy as np +import warnings +from tqdm import tqdm +import argparse + +# generate VidHOI from original VidOR annotations +def convert_vidor_to_ava_label(annot_dir): + frame_annots = [] + used_video_dict = set() + + for folder in tqdm(os.listdir(annot_dir)): + for video_json in os.listdir(os.path.join(annot_dir, folder)): + with open(os.path.join(annot_dir, folder, video_json), 'r') as f: + annot = json.load(f) + + if abs(annot['fps'] - 29.97) < 0.1: + fps = 30 + elif annot['fps'] - 24 < 1.01: # fps 24, 25 + fps = 24 + else: + raise(f"Invalid fps={annot['fps']}") + + for i in range(annot['frame_count']): + if (i - (fps // 2)) % fps != 0: # 1 sample/sec + continue + + idx = i-1 + for rel in annot['relation_instances']: + if rel['begin_fid'] <= idx < rel['end_fid'] \ + and annot['subject/objects'][rel['subject_tid']]['category'] in human_categories \ + and rel['predicate'] in pred_categories: + frame_annot = annot['trajectories'][idx] + + person_found = object_found = False + for ann in frame_annot: + if ann['tid'] == rel['subject_tid']: + person_annot, person_found = ann, True + elif ann['tid'] == rel['object_tid']: + object_annot, object_found = ann, True + if person_found and object_found: + break + + frame_annots.append({ + 'video_folder': folder, + 'video_id': annot['video_id'], + 'frame_id': str(f'{idx+1:06d}'), # real frame index start from 1 + 'video_fps': fps, # annot['fps'], + 'height': annot['height'], + 'width': annot['width'], + # 'middle_frame_timestamp': i // fps + 1, + 'person_box': person_annot['bbox'], + 'object_box': object_annot['bbox'], 
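The VidOR-to-VidHOI conversion above keeps exactly one key frame per second by testing `(i - fps // 2) % fps == 0` after snapping the annotated fps to 24 or 30. The following self-contained sketch reproduces that sampling rule; the frame counts and fps values are illustrative.

```
# Sketch of the one-key-frame-per-second rule used in convert_vidor_to_ava_label;
# frame counts and fps values below are illustrative.
def keyframe_indices(frame_count: int, fps: int):
    """0-based frame indices kept by the (i - fps // 2) % fps == 0 test."""
    return [i for i in range(frame_count) if (i - fps // 2) % fps == 0]

print(keyframe_indices(frame_count=90, fps=30))   # [15, 45, 75] -> one frame per second
print(keyframe_indices(frame_count=50, fps=24))   # [12, 36]
# note: the converter stores frame_id as idx + 1, i.e. 1-based frame names
```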
+ 'person_id': person_annot['tid'], + 'object_id': object_annot['tid'], + 'object_class': obj_to_idx['person'] if annot['subject/objects'][rel['object_tid']]['category'] in human_categories else obj_to_idx[annot['subject/objects'][rel['object_tid']]['category']], + 'action_class': pred_to_idx[rel['predicate']], + }) + + used_video_dict.add(folder + '/' + annot['video_id']) + return frame_annots, used_video_dict + +def dump_frames(args, relation_instance_annotations, split='train'): + video_dir = f"{args.video_dir}/{split}" + frame_dir = f"{args.frame_dir}/{split}" + + # Create video to frames mapping + video2frames, video2fps = {}, {} + for ann in relation_instance_annotations: # hoi items + video_key = f'{ann["video_folder"]}/{ann["video_id"]}' + frame = ann['frame_id'] + if video_key not in video2frames: + video2frames[video_key] = set() + video2frames[video_key].add(frame) + video2fps[video_key] = ann['video_fps'] + print(f"Total {split} #frames (with relations): {sum([len(v) for k, v in video2frames.items()])}") + + # For each video, dump frames. + print('Dumping video frames...') + for v in tqdm(video2frames): + curr_frame_dir = os.path.join(frame_dir, v) + + # keep frames with even sample rates + keep_frames = sorted([f'{x}' for x in list(video2frames[v])]) + keep_frames = [f"{idx:06d}.png" for idx in list(range(int(keep_frames[0]), int(keep_frames[-1])+1, video2fps[v]))] + + if not os.path.exists(curr_frame_dir): + os.makedirs(curr_frame_dir) + # Use ffmpeg to extract frames. Different versions of ffmpeg may generate slightly different frames. + # We used ffmpeg 2.8.15 to dump our frames. + os.system('ffmpeg -loglevel panic -i %s/%s.mp4 %s/%%06d.png' % (video_dir, v, curr_frame_dir)) + + # only keep the annotated frames included in frame_list.txt + frames_to_delete = set(os.listdir(curr_frame_dir)) - set(keep_frames) + for frame in frames_to_delete: + os.remove(os.path.join(curr_frame_dir, frame)) + elif set(os.listdir(curr_frame_dir)) != set(keep_frames): + print(f'Update dumping {curr_frame_dir}.') + os.system('ffmpeg -loglevel panic -i %s/%s.mp4 %s/%%06d.png' % (video_dir, v, curr_frame_dir)) + frames_to_delete = set(os.listdir(curr_frame_dir)) - set(keep_frames) + for frame in frames_to_delete: + os.remove(os.path.join(curr_frame_dir, frame)) + else: + print(f'Skip dumping {curr_frame_dir}.') + +def convert_to_coco_annotations(args, relation_instance_annotations, split='train', video_limit=None): + # video and key frames + relation_annotations_dict, video2fps = {}, {} + for rel_instance in relation_instance_annotations: + video_key = f"{rel_instance['video_folder']}/{rel_instance['video_id']}" + if video_key not in relation_annotations_dict: + relation_annotations_dict[video_key] = {} + if rel_instance['frame_id'] not in relation_annotations_dict[video_key]: + relation_annotations_dict[video_key][rel_instance['frame_id']] = [] + relation_annotations_dict[video_key][rel_instance['frame_id']].append({ + 'subject_tid': rel_instance['person_id'], + 'subject_class': 0, + 'object_tid': rel_instance['object_id'], + 'object_class': rel_instance['object_class'], + 'predicate': rel_instance['action_class'], + 'predicate_name': idx_to_pred[rel_instance['action_class']] + }) + video2fps[video_key] = rel_instance['video_fps'] + + saving_video_keys = list(relation_annotations_dict.keys()) + if video_limit is not None: saving_video_keys = sorted(saving_video_keys)[:video_limit] + # print(saving_video_keys) + + # to coco format + annotations_coco_format = { + 'type': 'instances', + 'images': 
[], + 'categories': [{'id': id, 'name': c, 'supercategory': c} for id, c in enumerate(obj_categories)], + 'annotations': [], + 'predicate_categories': [{'id': id, 'name': c, 'supercategory': c} for id, c in enumerate(pred_categories)], + 'relation_annotations': relation_annotations_dict, + 'sequences': saving_video_keys, + 'sequence_startend_image_ids': [] + } + + image_id, annotation_id = 0, 0 + for video_key in tqdm(saving_video_keys): + frames_with_rels = sorted(list(relation_annotations_dict[video_key])) + key_frames = [f"{idx:06d}" for idx in list(range(int(frames_with_rels[0]), int(frames_with_rels[-1])+1, video2fps[video_key]))] + + with open(f"{args.vidor_orig_annotation_dir}/{split}/{video_key}.json", 'r') as f: + orig_vidor_anntations = json.load(f) + + track_tid2infos = {x['tid']: x for x in orig_vidor_anntations['subject/objects']} + first_frame_image_id = image_id + for idx, frame_key in enumerate(key_frames): + annotations_coco_format['images'].append({ + 'id': image_id, + 'file_name': f"{video_key}/{frame_key}.png", + 'frame_id': idx, + 'first_frame_image_id': first_frame_image_id, + 'video_key': video_key, + 'frame_key': frame_key + }) + + # box instances in frames + frame_boxes = orig_vidor_anntations['trajectories'][int(frame_key)-1] + for box in frame_boxes: + box_cat = track_tid2infos[box['tid']]['category'] + xywh = [box['bbox']['xmin'], box['bbox']['ymin'], + box['bbox']['xmax']-box['bbox']['xmin']+1, + box['bbox']['ymax']-box['bbox']['ymin']+1] + assert xywh[2] > 0 and xywh[3] > 0 + annotations_coco_format['annotations'].append({ + 'id': annotation_id, + 'bbox': xywh, + 'image_id': image_id, + "segmentation": [], + "ignore": False, + "visibility": True, + "area": xywh[2] * xywh[3], + "iscrowd": 0, + "seq": video_key, + "category_id": 1 if box_cat in human_categories else obj_categories.index(box_cat), + "track_id": box['tid'], + }) + annotation_id += 1 + + image_id += 1 + annotations_coco_format['sequence_startend_image_ids'].append((first_frame_image_id, image_id-1)) + + # save annotations + annotation_file = f'{args.annotation_dir}/{split}_cocofmt.json' + if video_limit is not None: annotation_file = f'{args.annotation_dir}/{split}_v{video_limit}_cocofmt.json' + with open(annotation_file, 'w') as anno_file: + json.dump(annotations_coco_format, anno_file, indent=4) + print(f'Saved {split} annotaions to {annotation_file}') + print(f"{split} #keyframe (all): {len(annotations_coco_format['images'])}") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Dump frames") + parser.add_argument("--task", default="convert_coco_annotations", help="dump_frames | convert_coco_annotations") + parser.add_argument("--video_dir", default="data/VidHOI/VidOR/videos", help="Folder containing VidOR videos.") + parser.add_argument("--frame_dir", default="data/VidHOI/frames", help="Root folder containing frames to be dumped.") + parser.add_argument("--annotation_dir", default="data/VidHOI/VidHOI_annotations", help=("Folder containing VidHOI annotation files")) + parser.add_argument("--vidor_orig_annotation_dir", default="data/VidHOI/VidOR/annotations", help=("Original annotations of VidOR")) + args = parser.parse_args() + + # load meta_infos + human_categories = ['adult', 'child', 'baby'] + with open('data/VidHOI/VidHOI_annotations/obj_categories.json', 'r') as f: + obj_categories = json.load(f) + with open('data/VidHOI/VidHOI_annotations/obj_to_idx.pkl', 'rb') as f: + obj_to_idx = pickle.load(f) # exclude BG, person=0 + with 
open('data/VidHOI/VidHOI_annotations/idx_to_obj.pkl', 'rb') as f: + idx_to_obj = pickle.load(f) + print(f"#objects: {len(obj_categories)}") + obj_categories.insert(0, '__background__') + + with open('data/VidHOI/VidHOI_annotations/pred_categories.json', 'r') as f: + pred_categories = json.load(f) + with open('data/VidHOI/VidHOI_annotations/pred_to_idx.pkl', 'rb') as f: + pred_to_idx = pickle.load(f) + with open('data/VidHOI/VidHOI_annotations/idx_to_pred.pkl', 'rb') as f: + idx_to_pred = pickle.load(f) + print(f"#predicates: {len(pred_categories)}") + + ###### load or generate relation instances of VidHOI from VidOR ###### + # train + train_frame_ann_file = f'{args.annotation_dir}/train_frame_annots.json' + if os.path.isfile(train_frame_ann_file): + with open(train_frame_ann_file, 'r') as f: + train_relation_annots = json.load(f) + else: + train_relation_annots, _ = convert_vidor_to_ava_label(f'{args.vidor_orig_annotation_dir}/train') + with open(train_frame_ann_file, 'w') as f: + json.dump(train_relation_annots, f) + print(f'train hoi+hhi: {len(train_relation_annots)}') + + # val + val_frame_ann_file = f'{args.annotation_dir}/val_frame_annots.json' + if os.path.isfile(val_frame_ann_file): + with open(val_frame_ann_file, 'r') as f: + val_relation_annots = json.load(f) + else: + val_relation_annots, _ = convert_vidor_to_ava_label(f'{args.vidor_orig_annotation_dir}/validation') + with open(val_frame_ann_file, 'w') as f: + json.dump(val_relation_annots, f) + print(f'val hoi+hhi: {len(val_relation_annots)}') + + # pre-processing tasks + if args.task == 'dump_frames': + dump_frames(args, train_relation_annots, split='train') + dump_frames(args, val_relation_annots, split='validation') + elif args.task == 'convert_coco_annotations': + # convert_to_coco_annotations(args, train_relation_annots, split='train', video_limit=30) + # convert_to_coco_annotations(args, val_relation_annots, split='validation', video_limit=10) + convert_to_coco_annotations(args, train_relation_annots, split='train') + convert_to_coco_annotations(args, val_relation_annots, split='validation') + else: + raise(f'Unsupported task: {args.task}') diff --git a/src/parse_mot_results_to_tex.py b/src/parse_mot_results_to_tex.py new file mode 100644 index 0000000..f05af3d --- /dev/null +++ b/src/parse_mot_results_to_tex.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Parse MOT results and generate a LaTeX table. 
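`convert_to_coco_annotations` writes a standard COCO dictionary plus extra keys (`relation_annotations`, `sequences`, `sequence_startend_image_ids`). The sketch below shows how such a file can be loaded back with pycocotools; the path assumes the script's default `--annotation_dir` and `split='validation'` and may differ in practice.

```
# Sketch: re-loading a generated *_cocofmt.json with pycocotools. The path is
# an assumption about the local setup (default output for split='validation').
from pycocotools.coco import COCO

coco = COCO('data/VidHOI/VidHOI_annotations/validation_cocofmt.json')

# the standard COCO index works on the 'images'/'annotations'/'categories' keys
img_ids = coco.getImgIds()
anns = coco.loadAnns(coco.getAnnIds(imgIds=img_ids[:1]))

# the custom keys written above are preserved verbatim in coco.dataset
rel_annos = coco.dataset['relation_annotations']   # {video_key: {frame_key: [relations]}}
sequences = coco.dataset['sequences']
print(len(img_ids), len(anns), len(sequences))
```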
+""" + +MOTS = False +F_CONTENT = """ + MOTA IDF1 MOTP MT ML FP FN Recall Precision FAF IDSW Frag + MOT17-01-DPM 41.6 44.2 77.1 5 8 496 3252 49.6 86.6 1.1 22 58 + MOT17-01-FRCNN 41.0 42.1 77.1 6 9 571 3207 50.3 85.0 1.3 25 61 + MOT17-01-SDP 41.8 44.3 76.8 7 8 612 3112 51.8 84.5 1.4 27 65 + MOT17-03-DPM 79.3 71.6 79.1 94 8 1142 20297 80.6 98.7 0.8 191 525 + MOT17-03-FRCNN 79.6 72.7 79.1 93 7 1234 19945 80.9 98.6 0.8 180 508 + MOT17-03-SDP 80.0 72.0 79.0 93 8 1223 19530 81.3 98.6 0.8 181 526 + MOT17-06-DPM 54.8 42.0 79.5 54 63 314 4839 58.9 95.7 0.3 175 244 + MOT17-06-FRCNN 55.6 42.9 79.3 57 59 363 4676 60.3 95.1 0.3 190 264 + MOT17-06-SDP 55.5 43.8 79.3 56 61 354 4712 60.0 95.2 0.3 181 262 + MOT17-07-DPM 44.8 42.0 76.6 11 16 1322 7851 53.5 87.2 2.6 147 275 + MOT17-07-FRCNN 45.5 41.5 76.6 13 15 1263 7785 53.9 87.8 2.5 156 289 + MOT17-07-SDP 45.2 42.4 76.6 13 15 1332 7775 54.0 87.3 2.7 147 279 + MOT17-08-DPM 26.5 32.2 83.0 11 37 378 15066 28.7 94.1 0.6 88 146 + MOT17-08-FRCNN 26.5 31.9 83.1 11 36 332 15113 28.5 94.8 0.5 89 141 + MOT17-08-SDP 26.6 32.3 83.1 11 36 350 15067 28.7 94.5 0.6 91 147 + MOT17-12-DPM 46.1 53.1 82.7 16 45 207 4434 48.8 95.3 0.2 30 50 + MOT17-12-FRCNN 46.1 52.6 82.6 15 45 197 4443 48.7 95.5 0.2 30 48 + MOT17-12-SDP 46.0 53.0 82.6 16 45 221 4426 48.9 95.0 0.2 30 52 + MOT17-14-DPM 31.6 36.6 74.8 13 78 636 11812 36.1 91.3 0.8 196 331 + MOT17-14-FRCNN 31.6 37.6 74.6 13 77 780 11653 37.0 89.8 1.0 202 350 + MOT17-14-SDP 31.7 37.1 74.7 13 76 749 11677 36.8 90.1 1.0 205 344 + OVERALL 61.5 59.6 78.9 621 752 14076 200672 64.4 96.3 0.8 2583 4965 + """ + +# MOTS = True +# F_CONTENT = """ +# sMOTSA MOTSA MOTSP IDF1 MT ML MTR PTR MLR GT TP FP FN Rcll Prcn FM FMR IDSW IDSWR +# MOTS20-01 59.79 79.56 77.60 68.00 10 0 83.33 16.67 0.00 12 2742 255 364 88.28 91.49 37 41.91 16 18.1 +# MOTS20-06 63.91 78.72 82.85 65.14 115 22 60.53 27.89 11.58 190 8479 595 1335 86.40 93.44 218 252.32 158 182.9 +# MOTS20-07 43.17 58.52 76.59 53.60 15 17 25.86 44.83 29.31 58 8445 834 4433 65.58 91.01 177 269.91 75 114.4 +# MOTS20-12 62.04 74.64 84.93 76.83 41 9 60.29 26.47 13.24 68 5408 549 1063 83.57 90.78 76 90.94 29 34.7 +# OVERALL 54.86 69.92 80.62 63.58 181 48 55.18 30.18 14.63 328 25074 2233 7195 77.70 91.82 508 653.77 278 357.8 +# """ + +if __name__ == '__main__': + # remove empty lines at start and beginning of F_CONTENT + F_CONTENT = F_CONTENT.strip() + F_CONTENT = F_CONTENT.splitlines() + + start_ixs = range(1, len(F_CONTENT) - 1, 3) + if MOTS: + start_ixs = range(1, len(F_CONTENT) - 1) + + metrics_res = {} + + for i in range(len(['DPM', 'FRCNN', 'SDP'])): + for start in start_ixs: + f_list = F_CONTENT[start + i].strip().split('\t') + metrics_res[f_list[0]] = f_list[1:] + + if MOTS: + break + + metrics_names = F_CONTENT[0].replace('\n', '').split() + + print(metrics_names) + + metrics_res['ALL'] = F_CONTENT[-1].strip().split('\t')[1:] + + for full_seq_name, data in metrics_res.items(): + seq_name = '-'.join(full_seq_name.split('-')[:2]) + detection_name = full_seq_name.split('-')[-1] + + if MOTS: + print(f"{seq_name} & " + f"{float(data[metrics_names.index('sMOTSA')]):.1f} & " + f"{float(data[metrics_names.index('IDF1')]):.1f} & " + f"{float(data[metrics_names.index('MOTSA')]):.1f} & " + f"{data[metrics_names.index('FP')]} & " + f"{data[metrics_names.index('FN')]} & " + f"{data[metrics_names.index('IDSW')]} \\\\") + else: + print(f"{seq_name} & {detection_name} & " + f"{float(data[metrics_names.index('MOTA')]):.1f} & " + f"{float(data[metrics_names.index('IDF1')]):.1f} & " + 
f"{data[metrics_names.index('MT')]} & " + f"{data[metrics_names.index('ML')]} & " + f"{data[metrics_names.index('FP')]} & " + f"{data[metrics_names.index('FN')]} & " + f"{data[metrics_names.index('IDSW')]} \\\\") diff --git a/src/run_with_submitit.py b/src/run_with_submitit.py new file mode 100644 index 0000000..d2bb262 --- /dev/null +++ b/src/run_with_submitit.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +A script to run multinode training with submitit. +""" +import os +import sys +import uuid +from pathlib import Path +from argparse import Namespace + +import sacred +import submitit + +import train +from trackformer.util.misc import nested_dict_to_namespace + +WORK_DIR = str(Path(__file__).parent.absolute()) + + +ex = sacred.Experiment('submit', ingredients=[train.ex]) +ex.add_config('cfgs/submit.yaml') + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/storage/slurm").is_dir(): + path = Path(f"/storage/slurm/{user}/runs") + path.mkdir(exist_ok=True) + return path + raise RuntimeError("No shared folder available") + + +def get_init_file() -> Path: + # Init file must not exist, but it's parent dir must exist. + os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer: + def __init__(self, args: Namespace) -> None: + self.args = args + + def __call__(self) -> None: + sys.path.append(WORK_DIR) + + import train + self._setup_gpu_args() + train.train(self.args) + + def checkpoint(self) -> submitit.helpers.DelayedSubmission: + import os + + import submitit + + self.args.dist_url = get_init_file().as_uri() + checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") + if os.path.exists(checkpoint_file): + self.args.resume = checkpoint_file + self.args.resume_optim = True + self.args.resume_vis = True + self.args.load_mask_head_from_model = None + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self) -> None: + from pathlib import Path + + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) + print(self.args.output_dir) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(args: Namespace): + # Note that the folder will depend on the job_id, to easily track experiments + if args.job_dir == "": + args.job_dir = get_shared_folder() / "%j" + + executor = submitit.AutoExecutor( + folder=args.job_dir, cluster=args.cluster, slurm_max_num_timeout=30) + + # cluster setup is defined by environment variables + num_gpus_per_node = args.num_gpus + nodes = args.nodes + timeout_min = args.timeout + + if args.slurm_gres: + slurm_gres = args.slurm_gres + else: + slurm_gres = f'gpu:{num_gpus_per_node},VRAM:{args.vram}' + + executor.update_parameters( + mem_gb=args.mem_per_gpu * num_gpus_per_node, + # gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=2, + nodes=nodes, + timeout_min=timeout_min, # max is 60 * 72, + slurm_partition=args.slurm_partition, + slurm_constraint=args.slurm_constraint, + slurm_comment=args.slurm_comment, + slurm_exclude=args.slurm_exclude, + 
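`parse_mot_results_to_tex.py` above splits each tab-separated results row against the metric header and prints a LaTeX table row. A standalone sketch of that mapping for one MOT17 row taken from `F_CONTENT`; note that the parser above expects tab-separated values.

```
# Standalone sketch of the metrics-row -> LaTeX-row mapping done above, using
# one MOT17 row from F_CONTENT (tab-separated, as the parser expects).
metrics_names = "MOTA IDF1 MOTP MT ML FP FN Recall Precision FAF IDSW Frag".split()
row = "MOT17-01-FRCNN\t41.0\t42.1\t77.1\t6\t9\t571\t3207\t50.3\t85.0\t1.3\t25\t61"

name, *values = row.strip().split('\t')
data = dict(zip(metrics_names, values))

seq_name = '-'.join(name.split('-')[:2])          # MOT17-01
detection_name = name.split('-')[-1]              # FRCNN
print(f"{seq_name} & {detection_name} & "
      f"{float(data['MOTA']):.1f} & {float(data['IDF1']):.1f} & "
      f"{data['MT']} & {data['ML']} & {data['FP']} & {data['FN']} & {data['IDSW']} \\\\")
# -> MOT17-01 & FRCNN & 41.0 & 42.1 & 6 & 9 & 571 & 3207 & 25 \\
```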
slurm_gres=slurm_gres + ) + + executor.update_parameters(name="fair_track") + + args.train.dist_url = get_init_file().as_uri() + # args.output_dir = args.job_dir + + trainer = Trainer(args.train) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id) + + if args.cluster == 'debug': + job.wait() + + +@ex.main +def load_config(_config, _run): + """ We use sacred only for config loading from YAML files. """ + sacred.commands.print_config(_run) + + +if __name__ == '__main__': + # TODO: hierachical Namespacing for nested dict + config = ex.run_commandline().config + args = nested_dict_to_namespace(config) + # args.train = Namespace(**config['train']) + main(args) diff --git a/src/track.py b/src/track.py new file mode 100644 index 0000000..26fe923 --- /dev/null +++ b/src/track.py @@ -0,0 +1,207 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import os +import sys +import time +from os import path as osp + +import motmetrics as mm +import numpy as np +import sacred +import torch +import tqdm +import yaml +from torch.utils.data import DataLoader + +from trackformer.datasets.tracking import TrackDatasetFactory +from trackformer.models import build_model +from trackformer.models.tracker import Tracker +from trackformer.util.misc import nested_dict_to_namespace +from trackformer.util.track_utils import (evaluate_mot_accums, get_mot_accum, + interpolate_tracks, plot_sequence) + +mm.lap.default_solver = 'lap' + +ex = sacred.Experiment('track') +ex.add_config('cfgs/track.yaml') +ex.add_named_config('reid', 'cfgs/track_reid.yaml') + + +@ex.automain +def main(seed, dataset_name, obj_detect_checkpoint_file, tracker_cfg, + write_images, output_dir, interpolate, verbose, load_results_dir, + data_root_dir, generate_attention_maps, frame_range, + _config, _log, _run, obj_detector_model=None): + if write_images: + assert output_dir is not None + + # obj_detector_model is only provided when run as evaluation during + # training. in that case we omit verbose outputs. 
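`run_with_submitit.py` wraps training in a `Trainer` callable so submitit can checkpoint and requeue the job from the latest checkpoint. Below is a minimal standalone submitit sketch of the same submit-and-wait pattern without the sacred config plumbing; the log folder, partition name, and resource numbers are placeholders.

```
# Minimal submitit sketch mirroring the submission pattern above; the log
# folder, partition name and resource numbers are placeholders.
import submitit

def dummy_train(lr: float) -> float:
    # stands in for train.train(args)
    return 2 * lr

executor = submitit.AutoExecutor(folder="submitit_logs/%j")
executor.update_parameters(
    name="fair_track",
    timeout_min=60,
    nodes=1,
    tasks_per_node=1,
    cpus_per_task=2,
    slurm_partition="gpu",       # cluster specific
)

job = executor.submit(dummy_train, 0.1)
print("Submitted job_id:", job.job_id)
print("Result:", job.result())   # blocks until the job finishes
```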
+ if obj_detector_model is None: + sacred.commands.print_config(_run) + + # set all seeds + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + + if output_dir is not None: + if not osp.exists(output_dir): + os.makedirs(output_dir) + + yaml.dump( + _config, + open(osp.join(output_dir, 'track.yaml'), 'w'), + default_flow_style=False) + + ########################## + # Initialize the modules # + ########################## + + # object detection + if obj_detector_model is None: + obj_detect_config_path = os.path.join( + os.path.dirname(obj_detect_checkpoint_file), + 'config.yaml') + obj_detect_args = nested_dict_to_namespace(yaml.unsafe_load(open(obj_detect_config_path))) + img_transform = obj_detect_args.img_transform + obj_detector, _, obj_detector_post = build_model(obj_detect_args) + + print(f'Load model from {obj_detect_checkpoint_file}') + obj_detect_checkpoint = torch.load( + obj_detect_checkpoint_file, map_location=lambda storage, loc: storage) + + obj_detect_state_dict = obj_detect_checkpoint['model'] + # obj_detect_state_dict = { + # k: obj_detect_state_dict[k] if k in obj_detect_state_dict + # else v + # for k, v in obj_detector.state_dict().items()} + + obj_detect_state_dict = { + k.replace('detr.', ''): v + for k, v in obj_detect_state_dict.items() + if 'track_encoding' not in k} + + obj_detector.load_state_dict(obj_detect_state_dict) + if 'epoch' in obj_detect_checkpoint: + _log.info(f"INIT object detector [EPOCH: {obj_detect_checkpoint['epoch']}]") + + obj_detector.cuda() + else: + obj_detector = obj_detector_model['model'] + obj_detector_post = obj_detector_model['post'] + img_transform = obj_detector_model['img_transform'] + + if hasattr(obj_detector, 'tracking'): + obj_detector.tracking() + + track_logger = None + if verbose: + track_logger = _log.info + tracker = Tracker( + obj_detector, obj_detector_post, tracker_cfg, + generate_attention_maps, track_logger) + print(f'Tracker_config: {tracker_cfg}') + + time_total = 0 + num_frames = 0 + mot_accums = [] + dataset = TrackDatasetFactory( + dataset_name, root_dir=data_root_dir, img_transform=img_transform) + + for seq in dataset: + tracker.reset() + + _log.info(f"------------------") + _log.info(f"TRACK SEQ: {seq}") + + start_frame = int(frame_range['start'] * len(seq)) + end_frame = int(frame_range['end'] * len(seq)) + + seq_loader = DataLoader( + torch.utils.data.Subset(seq, range(start_frame, end_frame))) + + num_frames += len(seq_loader) + + results = seq.load_results(load_results_dir) + + if not results: + start = time.time() + + for frame_id, frame_data in enumerate(tqdm.tqdm(seq_loader, file=sys.stdout)): + with torch.no_grad(): + tracker.step(frame_data) + + results = tracker.get_results() + + time_total += time.time() - start + + _log.info(f"NUM TRACKS: {len(results)} ReIDs: {tracker.num_reids}") + _log.info(f"RUNTIME: {time.time() - start :.2f} s") + + if interpolate: + results = interpolate_tracks(results) + + if output_dir is not None: + _log.info(f"WRITE RESULTS") + seq.write_results(results, output_dir) + else: + _log.info("LOAD RESULTS") + + if seq.no_gt: + _log.info("NO GT AVAILBLE") + else: + mot_accum = get_mot_accum(results, seq_loader) + mot_accums.append(mot_accum) + + if verbose: + mot_events = mot_accum.mot_events + reid_events = mot_events[mot_events['Type'] == 'SWITCH'] + match_events = mot_events[mot_events['Type'] == 'MATCH'] + + switch_gaps = [] + for index, event in reid_events.iterrows(): + frame_id, _ = 
index + match_events_oid = match_events[match_events['OId'] == event['OId']] + match_events_oid_earlier = match_events_oid[ + match_events_oid.index.get_level_values('FrameId') < frame_id] + + if not match_events_oid_earlier.empty: + match_events_oid_earlier_frame_ids = \ + match_events_oid_earlier.index.get_level_values('FrameId') + last_occurrence = match_events_oid_earlier_frame_ids.max() + switch_gap = frame_id - last_occurrence + switch_gaps.append(switch_gap) + + switch_gaps_hist = None + if switch_gaps: + switch_gaps_hist, _ = np.histogram( + switch_gaps, bins=list(range(0, max(switch_gaps) + 10, 10))) + switch_gaps_hist = switch_gaps_hist.tolist() + + _log.info(f'SWITCH_GAPS_HIST (bin_width=10): {switch_gaps_hist}') + + if output_dir is not None and write_images: + _log.info("PLOT SEQ") + plot_sequence( + results, seq_loader, osp.join(output_dir, dataset_name, str(seq)), + write_images, generate_attention_maps) + + if time_total: + _log.info(f"RUNTIME ALL SEQS (w/o EVAL or IMG WRITE): " + f"{time_total:.2f} s for {num_frames} frames " + f"({num_frames / time_total:.2f} Hz)") + + if obj_detector_model is None: + _log.info(f"EVAL:") + + summary, str_summary = evaluate_mot_accums( + mot_accums, + [str(s) for s in dataset if not s.no_gt]) + + _log.info(f'\n{str_summary}') + + return summary + + return mot_accums diff --git a/src/track_param_search.py b/src/track_param_search.py new file mode 100644 index 0000000..0c1b5de --- /dev/null +++ b/src/track_param_search.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from itertools import product + +import numpy as np + +from track import ex + + +if __name__ == "__main__": + general_tracker_cfg = {'public_detections': False, 'reid_sim_only': True, 'reid_greedy_matching': True} + # general_tracker_cfg = {'public_detections': False,} + + # configs = [ + # {'dataset_name': ["MOT17-02-FRCNN", "MOT17-10-FRCNN", "MOT17-13-FRCNN"], + # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_1_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, + # {'dataset_name': ["MOT17-04-FRCNN", "MOT17-11-FRCNN"], + # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_2_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, + # {'dataset_name': ["MOT17-05-FRCNN", "MOT17-09-FRCNN"], + # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_3_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, + # ] + + configs = [ + {'dataset_name': ["MOT17-02-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_1_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-04-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_2_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-05-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_3_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-09-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_4_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-10-FRCNN"], + 'obj_detect_checkpoint_file': 
'/storage/user/meinhard/fair_track/models/mot17_train_5_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-11-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_6_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + {'dataset_name': ["MOT17-13-FRCNN"], + 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_7_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, + ] + + tracker_param_grids = { + 'detection_obj_score_thresh': [0.9], + 'track_obj_score_thresh': [0.8], + 'detection_nms_thresh': [0.7], + 'track_nms_thresh': [0.9], + 'reid_sim_threshold': [0.0], + 'reid_score_thresh': [0.8], + 'inactive_patience': [1]} + + # compute all config combinations + tracker_param_cfgs = [dict(zip(tracker_param_grids, v)) + for v in product(*tracker_param_grids.values())] + + # add empty metric arrays + metrics = ['mota', 'idf1'] + tracker_param_cfgs = [ + {'config': {**general_tracker_cfg, **tracker_cfg}} + for tracker_cfg in tracker_param_cfgs] + + for m in metrics: + for tracker_cfg in tracker_param_cfgs: + tracker_cfg[m] = [] + + total_num_experiments = len(tracker_param_cfgs) * len(configs) + print(f'NUM experiments: {total_num_experiments}') + + # run all tracker config combinations for all experiment configurations + exp_counter = 1 + for config in configs: + for tracker_cfg in tracker_param_cfgs: + print(f"EXPERIMENT: {exp_counter}/{total_num_experiments}") + + config['tracker_cfg'] = tracker_cfg['config'] + run = ex.run(config_updates=config) + eval_summary = run.result + + for m in metrics: + tracker_cfg[m].append(eval_summary[m]['OVERALL']) + + exp_counter += 1 + + # compute mean for all metrices + for m in metrics: + for tracker_cfg in tracker_param_cfgs: + tracker_cfg[m] = np.array(tracker_cfg[m]).mean() + + for cfg in tracker_param_cfgs: + print(cfg['config']) + print([cfg[m] for m in metrics]) + + # compute and plot best metric config + for m in metrics: + best_metric_cfg_idx = np.array( + [cfg[m] for cfg in tracker_param_cfgs]).argmax() + + print(f"BEST {m.upper()} CFG: {tracker_param_cfgs[best_metric_cfg_idx]['config']}") + + # TODO + best_mota_plus_idf1_cfg_idx = np.array( + [cfg['mota'] + cfg['idf1'] for cfg in tracker_param_cfgs]).argmax() + print(f"BEST MOTA PLUS IDF1 CFG: {tracker_param_cfgs[best_mota_plus_idf1_cfg_idx]['config']}") diff --git a/src/trackformer/__init__.py b/src/trackformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/trackformer/datasets/__init__.py b/src/trackformer/datasets/__init__.py new file mode 100644 index 0000000..7a1150c --- /dev/null +++ b/src/trackformer/datasets/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Submodule interface. 
+""" +from argparse import Namespace +from pycocotools.coco import COCO +from torch.utils.data import Dataset, Subset +from torchvision.datasets import CocoDetection + +from .coco import build as build_coco +from .crowdhuman import build_crowdhuman +from .mot import build_mot, build_mot_crowdhuman +from .actiongenome import build_actiongenome +from .vidhoi import build_vidhoi + +def get_coco_api_from_dataset(dataset: Subset) -> COCO: + """Return COCO class from PyTorch dataset for evaluation with COCO eval.""" + for _ in range(10): + # if isinstance(dataset, CocoDetection): + # break + if isinstance(dataset, Subset): + dataset = dataset.dataset + + if not isinstance(dataset, CocoDetection): + raise NotImplementedError + + return dataset.coco + + +def build_dataset(split: str, args: Namespace) -> Dataset: + """Helper function to build dataset for different splits ('train' or 'val').""" + if args.dataset == 'coco': + dataset = build_coco(split, args) + elif args.dataset == 'coco_person': + dataset = build_coco(split, args, 'person_keypoints') + elif args.dataset == 'mot': + dataset = build_mot(split, args) + elif args.dataset == 'crowdhuman': + dataset = build_crowdhuman(split, args) + elif args.dataset == 'mot_crowdhuman': + dataset = build_mot_crowdhuman(split, args) + elif args.dataset == 'coco_panoptic': + # to avoid making panopticapi required for coco + from .coco_panoptic import build as build_coco_panoptic + dataset = build_coco_panoptic(split, args) + elif args.dataset == 'actiongenome': + dataset = build_actiongenome(split, args) + elif args.dataset == 'vidhoi': + dataset = build_vidhoi(split, args) + else: + raise ValueError(f'dataset {args.dataset} not supported') + + return dataset diff --git a/src/trackformer/datasets/actiongenome.py b/src/trackformer/datasets/actiongenome.py new file mode 100644 index 0000000..4fcc7d7 --- /dev/null +++ b/src/trackformer/datasets/actiongenome.py @@ -0,0 +1,87 @@ +import numpy as np +from .coco import CocoDetection, make_coco_transforms +from pathlib import Path +import random +import copy + +class ActionGenome(CocoDetection): + + def __init__(self, img_folder, ann_file, transforms, return_masks, + prev_frame=False, prev_frame_rnd_augs=0.0, norm_transform=None, clip_length=None): + super(ActionGenome, self).__init__( + img_folder, ann_file, transforms, return_masks, False, + norm_transform, prev_frame, prev_frame_rnd_augs, clip_length=clip_length, dataset_name='actiongenome') + + def _add_frame_to_target(self, image_id, random_state): + random.setstate(random_state) + frame_img, frame_target = self._getitem_from_id(image_id) + frame_img, frame_target = self._norm_transforms(frame_img, frame_target) + + return frame_img, frame_target + + def sequence_infos(self): + seqs = self.coco.dataset['sequences'] + startend_image_ids = self.coco.dataset['sequence_startend_image_ids'] + startend_idx = [(self.ids.index(se[0]), self.ids.index(se[1])) for se in startend_image_ids] + return seqs, startend_idx + + def real_interval_to_prev_frame(self, org_img_id): + img_info = self.coco.imgs[org_img_id] + if img_info['id'] == img_info['first_frame_image_id']: + return 0 + else: + real_img_frame_idx = img_info['file_name'][-10:-4] + prev_img_id = self.ids[self.ids.index(org_img_id)-1] + prev_img_frame_idx = self.coco.imgs[prev_img_id]['file_name'][-10:-4] + return int(real_img_frame_idx) - int(prev_img_frame_idx) + + def __getitem__(self, idx): + random_state = random.getstate() + + img, target = self._getitem_from_id(idx) + img, target = 
self._norm_transforms(img, target) + + if self.clip_mode: + org_img_id = self.ids[idx] + org_img_info = self.coco.imgs[org_img_id] + frame_id = org_img_info['frame_id'] + start_id = self.ids.index(org_img_info['first_frame_image_id']) + prev_image_ids = np.sort((frame_id - np.arange(0, frame_id+1))[1:self.clip_length]) + start_id + + prev_frame_imgs, prev_frame_targets = [], [] + for prev_image_id in prev_image_ids: + frame_img, frame_target = self._add_frame_to_target(prev_image_id, random_state) + prev_frame_imgs.append(frame_img) + prev_frame_targets.append(frame_target) + + # compose clip + append_num = self.clip_length - len(prev_frame_imgs) + img = prev_frame_imgs + [img.clone() for _ in range(append_num)] + target = prev_frame_targets + [copy.deepcopy(target) for _ in range(append_num)] + + return img, target + + +def build_actiongenome(image_set, args): + root = Path(args.actiongenome_path) + assert root.exists(), f'provided ActionGenome path {root} does not exist' + + split = getattr(args, f"{image_set}_split") + + img_folder = f"{root}/frames" + ann_file = root / f"{split}_cocofmt.json" + + transforms, norm_transforms = make_coco_transforms( + image_set, args.img_transform, no_crop=True) + + dataset = ActionGenome( + img_folder, ann_file, + transforms=transforms, + norm_transform=norm_transforms, + return_masks=args.masks, + prev_frame=args.tracking, + prev_frame_rnd_augs=args.track_prev_frame_rnd_augs, + clip_length=args.clip_length + ) + + return dataset diff --git a/src/trackformer/datasets/actiongenome_eval.py b/src/trackformer/datasets/actiongenome_eval.py new file mode 100644 index 0000000..1d6650c --- /dev/null +++ b/src/trackformer/datasets/actiongenome_eval.py @@ -0,0 +1,302 @@ +## refer https://github.com/yrcong/STTran + +import torch +import torch.nn as nn +import numpy as np +from functools import reduce +from ..util import box_ops +from ..util import misc as utils + +class BasicSceneGraphEvaluator: + def __init__(self, mode, iou_threshold=0.5, constraint=False, dataset='actiongenome'): + self.dataset = dataset + self.result_dict = {} + self.mode = mode + self.result_dict[self.mode + '_recall'] = {10: [], 20: [], 50: [], 100: []} + self.constraint = constraint + self.iou_threshold = iou_threshold + + def reset_result(self): + self.result_dict[self.mode + '_recall'] = {10: [], 20: [], 50: [], 100: []} + + def print_stats(self): + print(f'======================{self.mode} (Constraint={self.constraint})============================') + for k, v in self.result_dict[self.mode + '_recall'].items(): + print('R@%i: %f' % (k, np.mean(v))) + print(f"#images = {len(v)}") + + def evaluate_scene_graph(self, gt, outputs, box_preds): + '''collect the groundtruth and prediction''' + + pred_top_rel_pairs = [] + for idx, frame_gt in enumerate(gt): + frame_box_pred = box_preds[frame_gt['image_id'].item()] + + # generate ground truth + boxes = box_ops.box_cxcywh_to_xyxy(frame_gt['boxes']) + img_h, img_w = frame_gt['orig_size'] + scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).to(boxes.device) + boxes = boxes * scale_fct + + gt_boxes = boxes.cpu().numpy().astype(float) + gt_classes = frame_gt['labels'].cpu().numpy() + gt_relations = frame_gt['relation_map'].nonzero().cpu().numpy() + + # relation prediction + pred_boxes = frame_box_pred['boxes'].cpu().numpy() + pred_classes = frame_box_pred['labels'].cpu().numpy() + pred_obj_scores = frame_box_pred['scores'].cpu().numpy() + + rel_pairs = outputs['pred_rel_pairs'][idx].cpu().numpy() + predicate_scores = 
outputs['pred_relations'][idx].sigmoid() + triplet_scores = predicate_scores # * outputs['pred_relation_exists'][idx].sigmoid().unsqueeze(-1) # * frame_box_pred['scores'][rel_pairs[:,0]].unsqueeze(1) * frame_box_pred['scores'][rel_pairs[:,1]].unsqueeze(1) + # triplet_scores = predicate_scores * frame_box_pred['scores'][rel_pairs[:,0]].unsqueeze(1) * frame_box_pred['scores'][rel_pairs[:,1]].unsqueeze(1) + + if self.constraint: # follow STTran, only for AG evaluation + attention_scores, attention_rel_inds = triplet_scores[:, :3].max(-1) + spatial_scores, spatial_rel_inds = triplet_scores[:, 3:9].max(-1); spatial_rel_inds += 3 + contacting_scores, contacting_rel_inds = triplet_scores[:, 9:].max(-1); contacting_rel_inds +=9 + all_rel_inds = torch.cat([torch.arange(len(triplet_scores))] * 3, dim=0) + all_scores = torch.cat([attention_scores, spatial_scores, contacting_scores], dim=0) + all_predicates = torch.cat([attention_rel_inds, spatial_rel_inds, contacting_rel_inds], dim=0) + + rel_scores, perm = all_scores.sort(descending=True) + pred_rels = np.column_stack([rel_pairs[all_rel_inds[perm]], all_predicates[perm].cpu().numpy()]) + rel_scores = rel_scores.cpu().numpy() + else: + score_inds = argsort_desc(triplet_scores.cpu().numpy())[:100] + pred_rels = np.column_stack([rel_pairs[score_inds[:, 0]], score_inds[:, 1]]) + rel_scores = triplet_scores.cpu().numpy()[score_inds[:, 0], score_inds[:, 1]] + + # # groundtruths as fake predictions + # pred_boxes = gt_boxes + # pred_classes = gt_classes + # pred_obj_scores = np.ones(len(pred_classes)) + # pred_rels = gt_relations + # rel_scores = np.ones(len(pred_rels)) + + pred_top_rel_pairs.append(pred_rels) + ################ evaluation ################ + if len(gt_relations) == 0: continue + pred_to_gt, pred_5ples, rel_scores = evaluate_recall( + gt_relations, gt_boxes, gt_classes, + pred_rels, pred_boxes, pred_classes, rel_scores, pred_obj_scores, + iou_thresh=self.iou_threshold) + + for k in self.result_dict[self.mode + '_recall']: + match = reduce(np.union1d, pred_to_gt[:k]) + + rec_i = float(len(match)) / float(gt_relations.shape[0]) + self.result_dict[self.mode + '_recall'][k].append(rec_i) + return pred_top_rel_pairs + + def synchronize_between_processes(self): + all_results = utils.all_gather([self.result_dict]) + metric_key = self.mode + '_recall' + + merged_result_dict = all_results[0][0] + for p in all_results[1:]: + for k, v in merged_result_dict[metric_key].items(): + v.extend(p[0][metric_key][k]) + self.result_dict = merged_result_dict + +########################### +def evaluate_recall(gt_rels, gt_boxes, gt_classes, + pred_rels, pred_boxes, pred_classes, rel_scores=None, cls_scores=None, + iou_thresh=0.5, phrdet=False): + """ + Evaluates the recall + :param gt_rels: [#gt_rel, 3] array of GT relations + :param gt_boxes: [#gt_box, 4] array of GT boxes + :param gt_classes: [#gt_box] array of GT classes + :param pred_rels: [#pred_rel, 3] array of pred rels. 
Assumed these are in sorted order + and refer to IDs in pred classes / pred boxes + (id0, id1, rel) + :param pred_boxes: [#pred_box, 4] array of pred boxes + :param pred_classes: [#pred_box] array of predicted classes for these boxes + :return: pred_to_gt: Matching from predicate to GT + pred_5ples: the predicted (id0, id1, cls0, cls1, rel) + rel_scores: [cls_0score, cls1_score, relscore] + """ + if pred_rels.size == 0: + return [[]], np.zeros((0,5)), np.zeros(0) + + num_gt_boxes = gt_boxes.shape[0] + num_gt_relations = gt_rels.shape[0] + assert num_gt_relations != 0 + + gt_triplets, gt_triplet_boxes, _ = _triplet(gt_rels[:, 2], + gt_rels[:, :2], + gt_classes, + gt_boxes) + num_boxes = pred_boxes.shape[0] + assert pred_rels[:,:2].max() < pred_classes.shape[0] + + # Exclude self rels + # assert np.all(pred_rels[:,0] != pred_rels[:,ĺeftright]) + #assert np.all(pred_rels[:,2] > 0) + + pred_triplets, pred_triplet_boxes, relation_scores = \ + _triplet(pred_rels[:,2], pred_rels[:,:2], pred_classes, pred_boxes, + rel_scores, cls_scores) + + sorted_scores = relation_scores.prod(1) + pred_triplets = pred_triplets[sorted_scores.argsort()[::-1],:] + pred_triplet_boxes = pred_triplet_boxes[sorted_scores.argsort()[::-1],:] + relation_scores = relation_scores[sorted_scores.argsort()[::-1],:] + scores_overall = relation_scores.prod(1) + + if not np.all(scores_overall[1:] <= scores_overall[:-1] + 1e-5): + print("Somehow the relations weren't sorted properly: \n{}".format(scores_overall)) + # raise ValueError("Somehow the relations werent sorted properly") + + # Compute recall. It's most efficient to match once and then do recall after + pred_to_gt = _compute_pred_matches( + gt_triplets, + pred_triplets, + gt_triplet_boxes, + pred_triplet_boxes, + iou_thresh, + phrdet=phrdet, + ) + + # Contains some extra stuff for visualization. Not needed. + pred_5ples = np.column_stack(( + pred_rels[:,:2], + pred_triplets[:, [0, 2, 1]], + )) + + return pred_to_gt, pred_5ples, relation_scores + + +def _triplet(predicates, relations, classes, boxes, + predicate_scores=None, class_scores=None): + """ + format predictions into triplets + :param predicates: A 1d numpy array of num_boxes*(num_boxes-ĺeftright) predicates, corresponding to + each pair of possibilities + :param relations: A (num_boxes*(num_boxes-ĺeftright), 2.0) array, where each row represents the boxes + in that relation + :param classes: A (num_boxes) array of the classes for each thing. + :param boxes: A (num_boxes,4) array of the bounding boxes for everything. + :param predicate_scores: A (num_boxes*(num_boxes-ĺeftright)) array of the scores for each predicate + :param class_scores: A (num_boxes) array of the likelihood for each object. 
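`evaluate_recall` returns `pred_to_gt`, a per-prediction list of matched ground-truth relation indices ordered by triplet score; `BasicSceneGraphEvaluator` then reduces the top-K lists with `np.union1d` to get Recall@K. A tiny worked example of that final step with toy values:

```
# Worked example of the Recall@K reduction used by BasicSceneGraphEvaluator;
# the pred_to_gt lists and GT count are toy values.
from functools import reduce
import numpy as np

# pred_to_gt[i] = GT relation indices matched by the i-th ranked prediction
pred_to_gt = [[0], [], [1], [0], [2]]
num_gt_relations = 4

for k in (1, 3, 5):
    match = reduce(np.union1d, pred_to_gt[:k], np.array([]))
    recall_at_k = len(match) / num_gt_relations
    print(f"R@{k}: {recall_at_k:.2f}")   # R@1: 0.25, R@3: 0.50, R@5: 0.75
```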
+ :return: Triplets: (num_relations, 3) array of class, relation, class + Triplet boxes: (num_relation, 8) array of boxes for the parts + Triplet scores: num_relation array of the scores overall for the triplets + """ + assert (predicates.shape[0] == relations.shape[0]) + + sub_ob_classes = classes[relations[:, :2]] + triplets = np.column_stack((sub_ob_classes[:, 0], predicates, sub_ob_classes[:, 1])) + triplet_boxes = np.column_stack((boxes[relations[:, 0]], boxes[relations[:, 1]])) + + triplet_scores = None + if predicate_scores is not None and class_scores is not None: + triplet_scores = np.column_stack(( + class_scores[relations[:, 0]], + class_scores[relations[:, 1]], + predicate_scores, + )) + + return triplets, triplet_boxes, triplet_scores + + +def _compute_pred_matches(gt_triplets, pred_triplets, + gt_boxes, pred_boxes, iou_thresh, phrdet=False): + """ + Given a set of predicted triplets, return the list of matching GT's for each of the + given predictions + :param gt_triplets: + :param pred_triplets: + :param gt_boxes: + :param pred_boxes: + :param iou_thresh: + :return: + """ + # This performs a matrix multiplication-esque thing between the two arrays + # Instead of summing, we want the equality, so we reduce in that way + # The rows correspond to GT triplets, columns to pred triplets + keeps = intersect_2d(gt_triplets, pred_triplets) + gt_has_match = keeps.any(1) + pred_to_gt = [[] for x in range(pred_boxes.shape[0])] + for gt_ind, gt_box, keep_inds in zip(np.where(gt_has_match)[0], + gt_boxes[gt_has_match], + keeps[gt_has_match], + ): + boxes = pred_boxes[keep_inds] + if phrdet: + # Evaluate where the union box > 0.5 + gt_box_union = gt_box.reshape((2, 4)) + gt_box_union = np.concatenate((gt_box_union.min(0)[:2], gt_box_union.max(0)[2:]), 0) + + box_union = boxes.reshape((-1, 2, 4)) + box_union = np.concatenate((box_union.min(1)[:,:2], box_union.max(1)[:,2:]), 1) + + inds = bbox_overlaps(gt_box_union[None], box_union)[0] >= iou_thresh + + else: + sub_iou = bbox_overlaps(gt_box[None,:4], boxes[:, :4])[0] + obj_iou = bbox_overlaps(gt_box[None,4:], boxes[:, 4:])[0] + + inds = (sub_iou >= iou_thresh) & (obj_iou >= iou_thresh) + + for i in np.where(keep_inds)[0][inds]: + pred_to_gt[i].append(int(gt_ind)) + return pred_to_gt + + +# From frcnn/utils/bbox.py +def bbox_overlaps(boxes, query_boxes): + """ + Parameters + ---------- + boxes: (N, 4) ndarray or tensor or variable + query_boxes: (K, 4) ndarray or tensor or variable + Returns + ------- + overlaps: (N, K) overlap between boxes and query_boxes + """ + if isinstance(boxes, np.ndarray): + boxes = torch.from_numpy(boxes) + query_boxes = torch.from_numpy(query_boxes) + out_fn = lambda x: x.numpy() # If input is ndarray, turn the overlaps back to ndarray when return + else: + out_fn = lambda x: x + + box_areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1) + query_areas = (query_boxes[:, 2] - query_boxes[:, 0] + 1) * (query_boxes[:, 3] - query_boxes[:, 1] + 1) + + iw = (torch.min(boxes[:, 2:3], query_boxes[:, 2:3].t()) - torch.max(boxes[:, 0:1], + query_boxes[:, 0:1].t()) + 1).clamp(min=0) + ih = (torch.min(boxes[:, 3:4], query_boxes[:, 3:4].t()) - torch.max(boxes[:, 1:2], + query_boxes[:, 1:2].t()) + 1).clamp(min=0) + ua = box_areas.view(-1, 1) + query_areas.view(1, -1) - iw * ih + overlaps = iw * ih / ua + return out_fn(overlaps) + +def intersect_2d(x1, x2): + """ + Given two arrays [m1, n], [m2,n], returns a [m1, m2] array where each entry is True if those + rows match. 
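`bbox_overlaps` above computes pairwise IoU in inclusive pixel coordinates, hence the `+ 1` terms inherited from the Faster R-CNN utilities. A toy check of that convention on two boxes; the coordinates are illustrative.

```
# Toy check of the inclusive-coordinate IoU convention used by bbox_overlaps
# (the +1 terms); the boxes below are illustrative.
def iou_inclusive(a, b):
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = iw * ih
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)

box = [0.0, 0.0, 9.0, 9.0]                        # a 10x10 box in inclusive coords
print(iou_inclusive(box, [0.0, 0.0, 9.0, 9.0]))   # 1.0  (identical boxes)
print(iou_inclusive(box, [5.0, 5.0, 14.0, 14.0])) # ~0.143 (25 / 175 overlap)
```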
+ :param x1: [m1, n] numpy array + :param x2: [m2, n] numpy array + :return: [m1, m2] bool array of the intersections + """ + if x1.shape[1] != x2.shape[1]: + raise ValueError("Input arrays must have same #columns") + + # This performs a matrix multiplication-esque thing between the two arrays + # Instead of summing, we want the equality, so we reduce in that way + res = (x1[..., None] == x2.T[None, ...]).all(1) + return res + + +def argsort_desc(scores): + """ + Returns the indices that sort scores descending in a smart way + :param scores: Numpy array of arbitrary size + :return: an array of size [numel(scores), dim(scores)] where each row is the index you'd + need to get the score. + """ + return np.column_stack(np.unravel_index(np.argsort(-scores.ravel()), scores.shape)) diff --git a/src/trackformer/datasets/coco.py b/src/trackformer/datasets/coco.py new file mode 100644 index 0000000..0768599 --- /dev/null +++ b/src/trackformer/datasets/coco.py @@ -0,0 +1,315 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +COCO dataset which returns image_id for evaluation. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" +import copy +import random +from pathlib import Path + +import torch +import torch.utils.data +import torchvision +from pycocotools import mask as coco_mask + +from . import transforms as T + + +class CocoDetection(torchvision.datasets.CocoDetection): + + fields = ["labels", "area", "iscrowd", "boxes", "track_ids", "masks"] + + def __init__(self, img_folder, ann_file, transforms, return_masks, + remove_no_obj_imgs=True, norm_transforms=None, + prev_frame=False, prev_frame_rnd_augs=0.05, clip_length=None, dataset_name=None): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + self._norm_transforms = norm_transforms + self.prepare = ConvertCocoPolysToMask(return_masks, dataset_name=dataset_name) + self.dataset_name = dataset_name + + if remove_no_obj_imgs: + self.ids = sorted(list(set( + [ann['image_id'] for ann in self.coco.loadAnns(self.coco.getAnnIds())]))) + + self._prev_frame = prev_frame + self._prev_frame_rnd_augs = prev_frame_rnd_augs + + self.clip_mode = isinstance(clip_length, int) + self.clip_length = clip_length + # self.ids = self.ids[:10] ## for debug + + def _getitem_from_id(self, image_id): + img, target = super(CocoDetection, self).__getitem__(image_id) + image_id = self.ids[image_id] + target = {'image_id': image_id, 'annotations': target} + if self.dataset_name == 'vidhoi': + org_image_info = self.coco.imgs[image_id] + if org_image_info['frame_key'] not in self.coco.dataset['relation_annotations'][org_image_info['video_key']]: + relation_annotations = [] + else: + relation_annotations = self.coco.dataset['relation_annotations'][org_image_info['video_key']][org_image_info['frame_key']] + target.update({ + 'image_key': f"{org_image_info['video_key']}/{org_image_info['frame_key']}", + 'relation_annotations': relation_annotations + }) + + img, target = self.prepare(img, target) + # target['track_ids'] = torch.arange(len(target['labels'])) ## !!! 
serious bug + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # ignore + ignore = target.pop("ignore").bool() + for field in self.fields: + if field in target: + target[f"{field}_ignore"] = target[field][ignore] + target[field] = target[field][~ignore] + + return img, target + + def __getitem__(self, idx): + img, target = self._getitem_from_id(idx) + + target['track_ids'] = torch.arange(len(target['labels'])) + + if self._prev_frame: + prev_img = img.copy() + prev_target = copy.deepcopy(target) + + orig_w, orig_h = img.size + + # prev img + w, h = prev_img.size + size = random.randint( + int((1.0 - self._prev_frame_rnd_augs) * min(w, h)), + int((1.0 + self._prev_frame_rnd_augs) * min(w, h))) + prev_img, prev_target = T.RandomResize([size])(prev_img, prev_target) + + w, h = prev_img.size + min_size = ( + int((1.0 - self._prev_frame_rnd_augs) * w), + int((1.0 - self._prev_frame_rnd_augs) * h)) + transform = T.RandomSizeCrop(min_size=min_size) + prev_img, prev_target = transform(prev_img, prev_target) + + w, h = prev_img.size + if orig_w < w: + prev_img, prev_target = T.RandomCrop((h, orig_w))(prev_img, prev_target) + else: + prev_img, prev_target = T.RandomPad(max_size=(orig_w, h))(prev_img, prev_target) + + w, h = prev_img.size + if orig_h < h: + prev_img, prev_target = T.RandomCrop((orig_h, w))(prev_img, prev_target) + else: + prev_img, prev_target = T.RandomPad(max_size=(w, orig_h))(prev_img, prev_target) + + img, target = self._norm_transforms(img, target) + + if self._prev_frame: + prev_img, prev_target = self._norm_transforms(prev_img, prev_target) + + if self.clip_mode: + img = [prev_img, img] + target = [prev_target, target] + else: + target['prev_image'] = prev_img + for k, v in prev_target.items(): + target[f'prev_{k}'] = v + + return img, target + + def write_result_files(self, *args): + pass + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + if isinstance(polygons, dict): + rles = {'size': polygons['size'], + 'counts': polygons['counts'].encode(encoding='UTF-8')} + else: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False, dataset_name=None): + self.return_masks = return_masks + self.dataset_name = dataset_name + + def __call__(self, image, org_target): + w, h = image.size + + image_id = org_target["image_id"] + image_id = torch.tensor([image_id]) + + anno = org_target["annotations"] + + anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + # x,y,w,h --> x,y,x,y + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + 
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+            num_keypoints = keypoints.shape[0]
+            if num_keypoints:
+                keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+        boxes = boxes[keep]
+        classes = classes[keep]
+        if self.return_masks:
+            masks = masks[keep]
+        if keypoints is not None:
+            keypoints = keypoints[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes - 1  ## !!! note: category ids are shifted down by one here, dropping the __background__ class for now
+
+        if self.return_masks:
+            target["masks"] = masks
+        target["image_id"] = image_id
+        if keypoints is not None:
+            target["keypoints"] = keypoints
+
+        if anno and "track_id" in anno[0]:
+            track_ids = torch.tensor([obj["track_id"] for obj in anno])
+            target["track_ids"] = track_ids[keep]
+        elif not len(boxes):
+            target["track_ids"] = torch.empty(0)
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+        ignore = torch.tensor([obj["ignore"] if "ignore" in obj else 0 for obj in anno])
+
+        target["area"] = area[keep]
+        target["iscrowd"] = iscrowd[keep]
+        target["ignore"] = ignore[keep]
+
+        target["orig_size"] = torch.as_tensor([int(h), int(w)])
+        target["size"] = torch.as_tensor([int(h), int(w)])
+
+        if self.dataset_name == 'actiongenome':
+            # !!! hard-coded for ActionGenome relations (assumes a single human instance)
+            target['relation_map'] = torch.zeros((len(boxes), len(boxes), 26))
+            human_instance_id = target["labels"].tolist().index(0)  # class 0: person
+            valid_insts = [a for a, flag in zip(org_target['annotations'], keep) if flag]
+            for inst_id, info in enumerate(valid_insts):
+                if 'relationships' in info:
+                    predicates = info['relationships']
+                    target['relation_map'][human_instance_id, inst_id, predicates] = 1
+        elif self.dataset_name == 'vidhoi':
+            assert len(org_target['annotations']) == len(boxes)
+            tid2idx = {x['track_id']: idx for idx, x in enumerate(org_target['annotations'])}
+            target['relation_map'] = torch.zeros((len(boxes), len(boxes), 50))
+            for rel_inst in org_target['relation_annotations']:
+                target['relation_map'][tid2idx[rel_inst['subject_tid']], tid2idx[rel_inst['object_tid']], rel_inst['predicate']] = 1
+
+        return image, target
+
+
+def make_coco_transforms(image_set, img_transform=None, no_crop=False):
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    # default
+    max_size = 1333
+    val_width = 800
+    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+    random_resizes = [400, 500, 600]
+    random_size_crop = (384, 600)
+
+    if img_transform is not None:
+        scale = img_transform.max_size / max_size
+        max_size = img_transform.max_size
+        val_width = img_transform.val_width
+
+        # scale all with respect to custom max_size
+        scales = [int(scale * s) for s in scales]
+        random_resizes = [int(scale * s) for s in random_resizes]
+        random_size_crop = [int(scale * s) for s in random_size_crop]
+
+    if image_set == 'train':
+        option2 = [
+            T.RandomResize(random_resizes),
+            T.RandomSizeCrop(*random_size_crop),
+            T.RandomResize(scales, max_size=max_size),
+        ]
+        if no_crop: option2.pop(1)  # avoid removing boxes by cropping
+
+        transforms = [
+            T.RandomHorizontalFlip(),
+            T.RandomSelect(
+                T.RandomResize(scales, max_size=max_size),
+                T.Compose(option2)
+            ),
+        ]
+    elif image_set == 'val':
+        transforms = [
+            T.RandomResize([val_width], max_size=max_size),
+        ]
+    else:
+        raise ValueError(f'unknown {image_set}')
+
+    # 
transforms.append(normalize) + return T.Compose(transforms), normalize + + +def build(image_set, args, mode='instances'): + root = Path(args.coco_path) + assert root.exists(), f'provided COCO path {root} does not exist' + + # image_set is 'train' or 'val' + split = getattr(args, f"{image_set}_split") + + splits = { + "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), + "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), + } + + transforms, norm_transforms = make_coco_transforms(image_set, args.img_transform) + img_folder, ann_file = splits[split] + dataset = CocoDetection( + img_folder, + ann_file, + transforms=transforms, + norm_transforms=norm_transforms, + return_masks=args.masks, + prev_frame=args.tracking, + prev_frame_rnd_augs=args.coco_and_crowdhuman_prev_frame_rnd_augs) + + return dataset diff --git a/src/trackformer/datasets/coco_eval.py b/src/trackformer/datasets/coco_eval.py new file mode 100644 index 0000000..3de6751 --- /dev/null +++ b/src/trackformer/datasets/coco_eval.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +COCO evaluator that works in distributed mode. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import os +import contextlib +import copy +import numpy as np +import torch + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from ..util.misc import all_gather + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for prediction in predictions.values(): + prediction["labels"] += 1 + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval( + self.coco_eval[iou_type], + self.img_ids, + self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print(f"IoU metric: {iou_type}") + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise 
ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + 
p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/src/trackformer/datasets/coco_panoptic.py b/src/trackformer/datasets/coco_panoptic.py new file mode 100644 index 0000000..b24f615 --- /dev/null +++ b/src/trackformer/datasets/coco_panoptic.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import json +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from panopticapi.utils import rgb2id +from util.box_ops import masks_to_boxes + +from .coco import make_coco_transforms + + +class CocoPanoptic: + def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): + with open(ann_file, 'r') as f: + self.coco = json.load(f) + + # sort 'images' field so that they are aligned with 'annotations' + # i.e., in alphabetical order + self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) + # sanity check + if "annotations" in self.coco: + for img, ann in zip(self.coco['images'], self.coco['annotations']): + assert img['file_name'][:-4] == ann['file_name'][:-4] + + self.img_folder = img_folder + self.ann_folder = ann_folder + self.ann_file = ann_file + self.transforms = transforms + self.return_masks = return_masks + + def __getitem__(self, idx): + ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] + img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') + ann_path = Path(self.ann_folder) / ann_info['file_name'] + + img = Image.open(img_path).convert('RGB') + w, h = img.size + if "segments_info" in ann_info: + masks = np.asarray(Image.open(ann_path), dtype=np.uint32) + masks = rgb2id(masks) + + ids = np.array([ann['id'] for ann in ann_info['segments_info']]) + masks = masks == ids[:, None, None] + + masks = torch.as_tensor(masks, dtype=torch.uint8) + labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) + + target = {} + target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) + if self.return_masks: + target['masks'] = masks + target['labels'] = labels + + target["boxes"] 
= masks_to_boxes(masks) + + target['size'] = torch.as_tensor([int(h), int(w)]) + target['orig_size'] = torch.as_tensor([int(h), int(w)]) + if "segments_info" in ann_info: + for name in ['iscrowd', 'area']: + target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.coco['images']) + + def get_height_and_width(self, idx): + img_info = self.coco['images'][idx] + height = img_info['height'] + width = img_info['width'] + return height, width + + +def build(image_set, args): + img_folder_root = Path(args.coco_path) + ann_folder_root = Path(args.coco_panoptic_path) + assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' + assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' + mode = 'panoptic' + PATHS = { + "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), + "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), + } + + img_folder, ann_file = PATHS[image_set] + img_folder_path = img_folder_root / img_folder + ann_folder = ann_folder_root / f'{mode}_{img_folder}' + ann_file = ann_folder_root / ann_file + + dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, + transforms=make_coco_transforms(image_set), return_masks=args.masks) + + return dataset diff --git a/src/trackformer/datasets/crowdhuman.py b/src/trackformer/datasets/crowdhuman.py new file mode 100644 index 0000000..7fb40e1 --- /dev/null +++ b/src/trackformer/datasets/crowdhuman.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +CrowdHuman dataset with tracking training augmentations. +""" +from pathlib import Path + +from .coco import CocoDetection, make_coco_transforms + + +def build_crowdhuman(image_set, args): + root = Path(args.crowdhuman_path) + assert root.exists(), f'provided COCO path {root} does not exist' + + split = getattr(args, f"{image_set}_split") + + img_folder = root / split + ann_file = root / f'annotations/{split}.json' + + transforms, norm_transforms = make_coco_transforms( + image_set, args.img_transform) + dataset = CocoDetection( + img_folder, + ann_file, + transforms=transforms, + norm_transforms=norm_transforms, + return_masks=args.masks, + prev_frame=args.tracking, + prev_frame_rnd_augs=args.coco_and_crowdhuman_prev_frame_rnd_augs, + clip_length=args.clip_length) + + return dataset diff --git a/src/trackformer/datasets/mot.py b/src/trackformer/datasets/mot.py new file mode 100644 index 0000000..1973440 --- /dev/null +++ b/src/trackformer/datasets/mot.py @@ -0,0 +1,234 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MOT dataset with tracking training augmentations. +""" +import bisect +import copy +import csv +import os +import random +from pathlib import Path + +import torch + +from . 
import transforms as T +from .coco import CocoDetection, make_coco_transforms +from .crowdhuman import build_crowdhuman +import numpy as np + + +class MOT(CocoDetection): + + def __init__(self, img_folder, ann_file, transforms, return_masks, + prev_frame=False, prev_frame_range=None, prev_frame_rnd_augs=0.0, norm_transform=None, clip_length=None): + super(MOT, self).__init__( + img_folder, ann_file, transforms, return_masks, False, + norm_transform, prev_frame, prev_frame_rnd_augs, clip_length=clip_length) + + self._prev_frame_range = prev_frame_range + + def sequence_infos(self): + seqs = self.coco.dataset['sequences'] + frame_names = [img['file_name'].split('_')[0] for img in self.coco.dataset['images']] + + startend_idx = [] + for seq in seqs: + all_inds = [i for i, n in enumerate(frame_names) if seq==n] + startend_idx.append((min(all_inds), max(all_inds))) + + return seqs, startend_idx + + @property + def sequences(self): + return self.coco.dataset['sequences'] + + @property + def frame_range(self): + if 'frame_range' in self.coco.dataset: + return self.coco.dataset['frame_range'] + else: + return {'start': 0, 'end': 1.0} + + def _add_frame_to_target(self, target, image_id, random_state, key_prefix): + random.setstate(random_state) + frame_img, frame_target = self._getitem_from_id(image_id) + + # # random jitter ## for debugging, remove aug temporally + # if self._prev_frame_rnd_augs and random.uniform(0, 1) < 0.5: + # # prev img + # orig_w, orig_h = frame_img.size + # + # width, height = frame_img.size + # size = random.randint( + # int((1.0 - self._prev_frame_rnd_augs) * min(width, height)), + # int((1.0 + self._prev_frame_rnd_augs) * min(width, height))) + # frame_img, frame_target = T.RandomResize([size])(frame_img, frame_target) + # + # width, height = frame_img.size + # min_size = ( + # int((1.0 - self._prev_frame_rnd_augs) * width), + # int((1.0 - self._prev_frame_rnd_augs) * height)) + # transform = T.RandomSizeCrop(min_size=min_size) + # frame_img, frame_target = transform(frame_img, frame_target) + # + # width, height = frame_img.size + # if orig_w < width: + # frame_img, frame_target = T.RandomCrop((height, orig_w))(frame_img, frame_target) + # else: + # frame_img, frame_target = T.RandomPad( + # max_size=(orig_w, height))(frame_img, frame_target) + # + # width, height = frame_img.size + # if orig_h < height: + # frame_img, frame_target = T.RandomCrop((orig_h, width))(frame_img, frame_target) + # else: + # frame_img, frame_target = T.RandomPad( + # max_size=(width, orig_h))(frame_img, frame_target) + # + # assert(frame_img.size[0] == orig_w and frame_img.size[1] == orig_h) + + frame_img, frame_target = self._norm_transforms(frame_img, frame_target) + + if self.clip_mode: + return frame_img, frame_target + else: + target[f'{key_prefix}_image'] = frame_img + for k, v in frame_target.items(): + target[f'{key_prefix}_{k}'] = v + + def seq_length(self, idx): + return self.coco.imgs[idx]['seq_length'] + + def sample_weight(self, idx): + return 1.0 / self.seq_length(idx) + + def __getitem__(self, idx): + random_state = random.getstate() + + img, target = self._getitem_from_id(idx) + img, target = self._norm_transforms(img, target) + + if self._prev_frame: + frame_id = self.coco.imgs[idx]['frame_id'] + + if self.clip_mode: + # assert self.clip_length >= 2 + sample_interval = np.random.randint(1, self._prev_frame_range+1) + prev_image_ids = np.sort((frame_id - np.arange(0, frame_id+1, sample_interval))[1:self.clip_length]) + self.coco.imgs[idx]['first_frame_image_id'] + + 
prev_frame_imgs, prev_frame_targets = [], []
+                for prev_image_id in prev_image_ids:
+                    frame_img, frame_target = self._add_frame_to_target(None, prev_image_id, random_state, 'prev')
+                    prev_frame_imgs.append(frame_img)
+                    prev_frame_targets.append(frame_target)
+
+                # compose clip
+                append_num = self.clip_length - len(prev_frame_imgs)
+                img = prev_frame_imgs + [img.clone() for _ in range(append_num)]
+                target = prev_frame_targets + [copy.deepcopy(target) for _ in range(append_num)]
+            else:  # originally only one previous frame
+                prev_frame_id = random.randint(
+                    max(0, frame_id - self._prev_frame_range),
+                    # min(frame_id + self._prev_frame_range, self.seq_length(idx) - 1))  ## prev_frame_id could end up ahead of the current frame
+                    min(frame_id, self.seq_length(idx) - 1))
+                prev_image_id = self.coco.imgs[idx]['first_frame_image_id'] + prev_frame_id
+                self._add_frame_to_target(target, prev_image_id, random_state, 'prev')
+
+        return img, target
+
+    def write_result_files(self, results, output_dir):
+        """Write the detections in the format for the MOT17Det submission
+
+        Each file contains these lines:
+        <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>
+
+        """
+
+        files = {}
+        for image_id, res in results.items():
+            img = self.coco.loadImgs(image_id)[0]
+            file_name_without_ext = os.path.splitext(img['file_name'])[0]
+            seq_name, frame = file_name_without_ext.split('_')
+            frame = int(frame)
+
+            outfile = os.path.join(output_dir, f"{seq_name}.txt")
+
+            # check if out in keys and create empty list if not
+            if outfile not in files.keys():
+                files[outfile] = []
+
+            for box, score in zip(res['boxes'], res['scores']):
+                if score <= 0.7:
+                    continue
+                x1 = box[0].item()
+                y1 = box[1].item()
+                x2 = box[2].item()
+                y2 = box[3].item()
+                files[outfile].append(
+                    [frame, -1, x1, y1, x2 - x1, y2 - y1, score.item(), -1, -1, -1])
+
+        for k, v in files.items():
+            with open(k, "w") as of:
+                writer = csv.writer(of, delimiter=',')
+                for d in v:
+                    writer.writerow(d)
+
+
+class WeightedConcatDataset(torch.utils.data.ConcatDataset):
+
+    def sample_weight(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        if hasattr(self.datasets[dataset_idx], 'sample_weight'):
+            return self.datasets[dataset_idx].sample_weight(sample_idx)
+        else:
+            return 1 / len(self.datasets[dataset_idx])
+
+
+def build_mot(image_set, args):
+    root = Path(args.mot_path)
+    assert root.exists(), f'provided MOT17Det path {root} does not exist'
+
+    split = getattr(args, f"{image_set}_split")
+
+    img_folder = root / split
+    ann_file = root / f"annotations/{split}.json"
+
+    transforms, norm_transforms = make_coco_transforms(
+        image_set, args.img_transform)
+
+    dataset = MOT(
+        img_folder, ann_file,
+        transforms=transforms,
+        norm_transform=norm_transforms,
+        return_masks=args.masks,
+        prev_frame=args.tracking,
+        prev_frame_range=args.track_prev_frame_range,
+        prev_frame_rnd_augs=args.track_prev_frame_rnd_augs,
+        clip_length=args.clip_length
+    )
+
+    return dataset
+
+
+def build_mot_crowdhuman(image_set, args):
+    if image_set == 'train':
+        args_crowdhuman = copy.deepcopy(args)
+        args_crowdhuman.train_split = args.crowdhuman_train_split
+
+        crowdhuman_dataset = build_crowdhuman('train', args_crowdhuman)
+
+        if getattr(args, f"{image_set}_split") is None:
+            return crowdhuman_dataset
+
+    dataset = build_mot(image_set, args)
+
+    if image_set == 'train':
+        dataset = torch.utils.data.ConcatDataset(
+            [dataset, crowdhuman_dataset])
+
+    return dataset
diff --git 
a/src/trackformer/datasets/panoptic_eval.py b/src/trackformer/datasets/panoptic_eval.py new file mode 100644 index 0000000..dec51bc --- /dev/null +++ b/src/trackformer/datasets/panoptic_eval.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import json +import os + +from ..util import misc as utils + +try: + from panopticapi.evaluation import pq_compute +except ImportError: + pass + + +class PanopticEvaluator(object): + def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): + self.gt_json = ann_file + self.gt_folder = ann_folder + if utils.is_main_process(): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + self.output_dir = output_dir + self.predictions = [] + + def update(self, predictions): + for p in predictions: + with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: + f.write(p.pop("png_string")) + + self.predictions += predictions + + def synchronize_between_processes(self): + all_predictions = utils.all_gather(self.predictions) + merged_predictions = [] + for p in all_predictions: + merged_predictions += p + self.predictions = merged_predictions + + def summarize(self): + if utils.is_main_process(): + json_data = {"annotations": self.predictions} + predictions_json = os.path.join(self.output_dir, "predictions.json") + with open(predictions_json, "w") as f: + f.write(json.dumps(json_data)) + return pq_compute( + self.gt_json, predictions_json, + gt_folder=self.gt_folder, pred_folder=self.output_dir) + return None diff --git a/src/trackformer/datasets/tracking/__init__.py b/src/trackformer/datasets/tracking/__init__.py new file mode 100644 index 0000000..c3bd70b --- /dev/null +++ b/src/trackformer/datasets/tracking/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Submodule interface. +""" +from .factory import TrackDatasetFactory diff --git a/src/trackformer/datasets/tracking/demo_sequence.py b/src/trackformer/datasets/tracking/demo_sequence.py new file mode 100644 index 0000000..56f7b6b --- /dev/null +++ b/src/trackformer/datasets/tracking/demo_sequence.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MOT17 sequence dataset. +""" +import configparser +import csv +import os +from pathlib import Path +import os.path as osp +from argparse import Namespace +from typing import Optional, Tuple, List + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from ..coco import make_coco_transforms +from ..transforms import Compose + + +class DemoSequence(Dataset): + """DemoSequence (MOT17) Dataset. + """ + + def __init__(self, root_dir: str = 'data', img_transform: Namespace = None) -> None: + """ + Args: + seq_name (string): Sequence to take + vis_threshold (float): Threshold of visibility of persons + above which they are selected + """ + super().__init__() + + self._data_dir = Path(root_dir) + assert self._data_dir.is_dir(), f'data_root_dir:{root_dir} does not exist.' 
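+        # DemoSequence carries no annotations: it simply lists the image files found in
+        # `root_dir` and applies the validation transforms, so `no_gt` is set to True below.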
+ + self.transforms = Compose(make_coco_transforms('val', img_transform)) + + self.data = self._sequence() + self.no_gt = True + + def __len__(self) -> int: + return len(self.data) + + def __str__(self) -> str: + return self._data_dir.name + + def __getitem__(self, idx: int) -> dict: + """Return the ith image converted to blob""" + data = self.data[idx] + img = Image.open(data['im_path']).convert("RGB") + width_orig, height_orig = img.size + + img, _ = self.transforms(img) + width, height = img.size(2), img.size(1) + + sample = {} + sample['img'] = img + sample['img_path'] = data['im_path'] + sample['dets'] = torch.tensor([]) + sample['orig_size'] = torch.as_tensor([int(height_orig), int(width_orig)]) + sample['size'] = torch.as_tensor([int(height), int(width)]) + + return sample + + def _sequence(self) -> List[dict]: + total = [] + for filename in sorted(os.listdir(self._data_dir)): + extension = os.path.splitext(filename)[1] + if extension in ['.png', '.jpg']: + total.append({'im_path': osp.join(self._data_dir, filename)}) + + return total + + def load_results(self, results_dir: str) -> dict: + return {} + + def write_results(self, results: dict, output_dir: str) -> None: + """Write the tracks in the format for MOT16/MOT17 sumbission + + results: dictionary with 1 dictionary for every track with + {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num + + Each file contains these lines: + , , , , , , , , , + """ + + # format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + result_file_path = osp.join(output_dir, self._data_dir.name) + + with open(result_file_path, "w") as r_file: + writer = csv.writer(r_file, delimiter=',') + + for i, track in results.items(): + for frame, data in track.items(): + x1 = data['bbox'][0] + y1 = data['bbox'][1] + x2 = data['bbox'][2] + y2 = data['bbox'][3] + + writer.writerow([ + frame + 1, + i + 1, + x1 + 1, + y1 + 1, + x2 - x1 + 1, + y2 - y1 + 1, + -1, -1, -1, -1]) diff --git a/src/trackformer/datasets/tracking/factory.py b/src/trackformer/datasets/tracking/factory.py new file mode 100644 index 0000000..869231d --- /dev/null +++ b/src/trackformer/datasets/tracking/factory.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Factory of tracking datasets. +""" +from typing import Union + +from torch.utils.data import ConcatDataset + +from .mot_wrapper import MOT17Wrapper, MOTS20Wrapper +from .demo_sequence import DemoSequence + +DATASETS = {} + +# Fill all available datasets, change here to modify / add new datasets. +for split in ['TRAIN', 'TEST', 'ALL', '01', '02', '03', '04', '05', + '06', '07', '08', '09', '10', '11', '12', '13', '14']: + for dets in ['DPM', 'FRCNN', 'SDP', 'ALL']: + name = f'MOT17-{split}' + if dets: + name = f"{name}-{dets}" + DATASETS[name] = ( + lambda kwargs, split=split, dets=dets: MOT17Wrapper(split, dets, **kwargs)) + + +for split in ['TRAIN', 'TEST', 'ALL', '01', '02', '05', '06', '07', '09', '11', '12']: + name = f'MOTS20-{split}' + DATASETS[name] = ( + lambda kwargs, split=split: MOTS20Wrapper(split, **kwargs)) + +DATASETS['DEMO'] = (lambda kwargs: [DemoSequence(**kwargs), ]) + + +class TrackDatasetFactory: + """A central class to manage the individual dataset loaders. + + This class contains the datasets. Once initialized the individual parts (e.g. sequences) + can be accessed. + """ + + def __init__(self, datasets: Union[str, list], **kwargs) -> None: + """Initialize the corresponding dataloader. 
+ + Keyword arguments: + datasets -- the name of the dataset or list of dataset names + kwargs -- arguments used to call the datasets + """ + if isinstance(datasets, str): + datasets = [datasets] + + self._data = None + for dataset in datasets: + assert dataset in DATASETS, f"[!] Dataset not found: {dataset}" + + if self._data is None: + self._data = DATASETS[dataset](kwargs) + else: + self._data = ConcatDataset([self._data, DATASETS[dataset](kwargs)]) + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, idx: int): + return self._data[idx] diff --git a/src/trackformer/datasets/tracking/mot17_sequence.py b/src/trackformer/datasets/tracking/mot17_sequence.py new file mode 100644 index 0000000..2e3854a --- /dev/null +++ b/src/trackformer/datasets/tracking/mot17_sequence.py @@ -0,0 +1,272 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MOT17 sequence dataset. +""" +import configparser +import csv +import os +import os.path as osp +from argparse import Namespace +from typing import Optional, Tuple, List + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from ..coco import make_coco_transforms +from ..transforms import Compose + + +class MOT17Sequence(Dataset): + """Multiple Object Tracking (MOT17) Dataset. + + This dataloader is designed so that it can handle only one sequence, + if more have to be handled one should inherit from this class. + """ + data_folder = 'MOT17' + + def __init__(self, root_dir: str = 'data', seq_name: Optional[str] = None, + dets: str = '', vis_threshold: float = 0.0, img_transform: Namespace = None) -> None: + """ + Args: + seq_name (string): Sequence to take + vis_threshold (float): Threshold of visibility of persons + above which they are selected + """ + super().__init__() + + self._seq_name = seq_name + self._dets = dets + self._vis_threshold = vis_threshold + + self._data_dir = osp.join(root_dir, self.data_folder) + + self._train_folders = os.listdir(os.path.join(self._data_dir, 'train')) + self._test_folders = os.listdir(os.path.join(self._data_dir, 'test')) + + self.transforms = Compose(make_coco_transforms('val', img_transform)) + + self.data = [] + self.no_gt = True + if seq_name is not None: + full_seq_name = seq_name + if self._dets is not None: + full_seq_name = f"{seq_name}-{dets}" + assert full_seq_name in self._train_folders or full_seq_name in self._test_folders, \ + 'Image set does not exist: {}'.format(full_seq_name) + + self.data = self._sequence() + self.no_gt = not osp.exists(self.get_gt_file_path()) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> dict: + """Return the ith image converted to blob""" + data = self.data[idx] + img = Image.open(data['im_path']).convert("RGB") + width_orig, height_orig = img.size + + img, _ = self.transforms(img) + width, height = img.size(2), img.size(1) + + sample = {} + sample['img'] = img + sample['dets'] = torch.tensor([det[:4] for det in data['dets']]) + sample['img_path'] = data['im_path'] + sample['gt'] = data['gt'] + sample['vis'] = data['vis'] + sample['orig_size'] = torch.as_tensor([int(height_orig), int(width_orig)]) + sample['size'] = torch.as_tensor([int(height), int(width)]) + + return sample + + def _sequence(self) -> List[dict]: + # public detections + dets = {i: [] for i in range(1, self.seq_length + 1)} + det_file = self.get_det_file_path() + + if osp.exists(det_file): + with open(det_file, "r") as inf: + reader = csv.reader(inf, delimiter=',') + 
for row in reader: + x1 = float(row[2]) - 1 + y1 = float(row[3]) - 1 + # This -1 accounts for the width (width of 1 x1=x2) + x2 = x1 + float(row[4]) - 1 + y2 = y1 + float(row[5]) - 1 + score = float(row[6]) + bbox = np.array([x1, y1, x2, y2, score], dtype=np.float32) + dets[int(float(row[0]))].append(bbox) + + # accumulate total + img_dir = osp.join( + self.get_seq_path(), + self.config['Sequence']['imDir']) + + boxes, visibility = self.get_track_boxes_and_visbility() + + total = [ + {'gt': boxes[i], + 'im_path': osp.join(img_dir, f"{i:06d}.jpg"), + 'vis': visibility[i], + 'dets': dets[i]} + for i in range(1, self.seq_length + 1)] + + return total + + def get_track_boxes_and_visbility(self) -> Tuple[dict, dict]: + """ Load ground truth boxes and their visibility.""" + boxes = {} + visibility = {} + + for i in range(1, self.seq_length + 1): + boxes[i] = {} + visibility[i] = {} + + gt_file = self.get_gt_file_path() + if not osp.exists(gt_file): + return boxes, visibility + + with open(gt_file, "r") as inf: + reader = csv.reader(inf, delimiter=',') + for row in reader: + # class person, certainity 1 + if int(row[6]) == 1 and int(row[7]) == 1 and float(row[8]) >= self._vis_threshold: + # Make pixel indexes 0-based, should already be 0-based (or not) + x1 = int(row[2]) - 1 + y1 = int(row[3]) - 1 + # This -1 accounts for the width (width of 1 x1=x2) + x2 = x1 + int(row[4]) - 1 + y2 = y1 + int(row[5]) - 1 + bbox = np.array([x1, y1, x2, y2], dtype=np.float32) + + frame_id = int(row[0]) + track_id = int(row[1]) + + boxes[frame_id][track_id] = bbox + visibility[frame_id][track_id] = float(row[8]) + + return boxes, visibility + + def get_seq_path(self) -> str: + """ Return directory path of sequence. """ + full_seq_name = self._seq_name + if self._dets is not None: + full_seq_name = f"{self._seq_name}-{self._dets}" + + if full_seq_name in self._train_folders: + return osp.join(self._data_dir, 'train', full_seq_name) + else: + return osp.join(self._data_dir, 'test', full_seq_name) + + def get_config_file_path(self) -> str: + """ Return config file of sequence. """ + return osp.join(self.get_seq_path(), 'seqinfo.ini') + + def get_gt_file_path(self) -> str: + """ Return ground truth file of sequence. """ + return osp.join(self.get_seq_path(), 'gt', 'gt.txt') + + def get_det_file_path(self) -> str: + """ Return public detections file of sequence. """ + if self._dets is None: + return "" + + return osp.join(self.get_seq_path(), 'det', 'det.txt') + + @property + def config(self) -> dict: + """ Return config of sequence. """ + config_file = self.get_config_file_path() + + assert osp.exists(config_file), \ + f'Config file does not exist: {config_file}' + + config = configparser.ConfigParser() + config.read(config_file) + return config + + @property + def seq_length(self) -> int: + """ Return sequence length, i.e, number of frames. """ + return int(self.config['Sequence']['seqLength']) + + def __str__(self) -> str: + return f"{self._seq_name}-{self._dets}" + + @property + def results_file_name(self) -> str: + """ Generate file name of results file. """ + assert self._seq_name is not None, "[!] 
No seq_name, probably using combined database" + + if self._dets is None: + return f"{self._seq_name}.txt" + + return f"{self}.txt" + + def write_results(self, results: dict, output_dir: str) -> None: + """Write the tracks in the format for MOT16/MOT17 sumbission + + results: dictionary with 1 dictionary for every track with + {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num + + Each file contains these lines: + , , , , , , , , , + """ + + # format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + result_file_path = osp.join(output_dir, self.results_file_name) + + with open(result_file_path, "w") as r_file: + writer = csv.writer(r_file, delimiter=',') + + for i, track in results.items(): + for frame, data in track.items(): + x1 = data['bbox'][0] + y1 = data['bbox'][1] + x2 = data['bbox'][2] + y2 = data['bbox'][3] + + writer.writerow([ + frame + 1, + i + 1, + x1 + 1, + y1 + 1, + x2 - x1 + 1, + y2 - y1 + 1, + -1, -1, -1, -1]) + + def load_results(self, results_dir: str) -> dict: + results = {} + if results_dir is None: + return results + + file_path = osp.join(results_dir, self.results_file_name) + + if not os.path.isfile(file_path): + return results + + with open(file_path, "r") as file: + csv_reader = csv.reader(file, delimiter=',') + + for row in csv_reader: + frame_id, track_id = int(row[0]) - 1, int(row[1]) - 1 + + if track_id not in results: + results[track_id] = {} + + x1 = float(row[2]) - 1 + y1 = float(row[3]) - 1 + x2 = float(row[4]) - 1 + x1 + y2 = float(row[5]) - 1 + y1 + + results[track_id][frame_id] = {} + results[track_id][frame_id]['bbox'] = [x1, y1, x2, y2] + results[track_id][frame_id]['score'] = 1.0 + + return results diff --git a/src/trackformer/datasets/tracking/mot_wrapper.py b/src/trackformer/datasets/tracking/mot_wrapper.py new file mode 100644 index 0000000..91aee87 --- /dev/null +++ b/src/trackformer/datasets/tracking/mot_wrapper.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MOT wrapper which combines sequences to a dataset. +""" +from torch.utils.data import Dataset + +from .mot17_sequence import MOT17Sequence +from .mots20_sequence import MOTS20Sequence + + +class MOT17Wrapper(Dataset): + """A Wrapper for the MOT_Sequence class to return multiple sequences.""" + + def __init__(self, split: str, dets: str, **kwargs) -> None: + """Initliazes all subset of the dataset. 
+ + Keyword arguments: + split -- the split of the dataset to use + kwargs -- kwargs for the MOT17Sequence dataset + """ + train_sequences = [ + 'MOT17-02', 'MOT17-04', 'MOT17-05', 'MOT17-09', + 'MOT17-10', 'MOT17-11', 'MOT17-13'] + test_sequences = [ + 'MOT17-01', 'MOT17-03', 'MOT17-06', 'MOT17-07', + 'MOT17-08', 'MOT17-12', 'MOT17-14'] + + if split == "TRAIN": + sequences = train_sequences + elif split == "TEST": + sequences = test_sequences + elif split == "ALL": + sequences = train_sequences + test_sequences + sequences = sorted(sequences) + elif f"MOT17-{split}" in train_sequences + test_sequences: + sequences = [f"MOT17-{split}"] + else: + raise NotImplementedError("MOT17 split not available.") + + self._data = [] + for seq in sequences: + if dets == 'ALL': + self._data.append(MOT17Sequence(seq_name=seq, dets='DPM', **kwargs)) + self._data.append(MOT17Sequence(seq_name=seq, dets='FRCNN', **kwargs)) + self._data.append(MOT17Sequence(seq_name=seq, dets='SDP', **kwargs)) + else: + self._data.append(MOT17Sequence(seq_name=seq, dets=dets, **kwargs)) + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, idx: int): + return self._data[idx] + + +class MOTS20Wrapper(MOT17Wrapper): + """A Wrapper for the MOT_Sequence class to return multiple sequences.""" + + def __init__(self, split: str, **kwargs) -> None: + """Initliazes all subset of the dataset. + + Keyword arguments: + split -- the split of the dataset to use + kwargs -- kwargs for the MOTS20Sequence dataset + """ + train_sequences = ['MOTS20-02', 'MOTS20-05', 'MOTS20-09', 'MOTS20-11'] + test_sequences = ['MOTS20-01', 'MOTS20-06', 'MOTS20-07', 'MOTS20-12'] + + if split == "TRAIN": + sequences = train_sequences + elif split == "TEST": + sequences = test_sequences + elif split == "ALL": + sequences = train_sequences + test_sequences + sequences = sorted(sequences) + elif f"MOTS20-{split}" in train_sequences + test_sequences: + sequences = [f"MOTS20-{split}"] + else: + raise NotImplementedError("MOTS20 split not available.") + + self._data = [] + for seq in sequences: + self._data.append(MOTS20Sequence(seq_name=seq, **kwargs)) diff --git a/src/trackformer/datasets/tracking/mots20_sequence.py b/src/trackformer/datasets/tracking/mots20_sequence.py new file mode 100644 index 0000000..7c43016 --- /dev/null +++ b/src/trackformer/datasets/tracking/mots20_sequence.py @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MOTS20 sequence dataset. +""" +import csv +import os +import os.path as osp +from argparse import Namespace +from typing import Optional, Tuple + +import numpy as np +import pycocotools.mask as rletools + +from .mot17_sequence import MOT17Sequence + + +class MOTS20Sequence(MOT17Sequence): + """Multiple Object and Segmentation Tracking (MOTS20) Dataset. + + This dataloader is designed so that it can handle only one sequence, + if more have to be handled one should inherit from this class. 
+ """ + data_folder = 'MOTS20' + + def __init__(self, root_dir: str = 'data', seq_name: Optional[str] = None, + vis_threshold: float = 0.0, img_transform: Namespace = None) -> None: + """ + Args: + seq_name (string): Sequence to take + vis_threshold (float): Threshold of visibility of persons + above which they are selected + """ + super().__init__(root_dir, seq_name, None, vis_threshold, img_transform) + + def get_track_boxes_and_visbility(self) -> Tuple[dict, dict]: + boxes = {} + visibility = {} + + for i in range(1, self.seq_length + 1): + boxes[i] = {} + visibility[i] = {} + + gt_file = self.get_gt_file_path() + if not osp.exists(gt_file): + return boxes, visibility + + mask_objects_per_frame = load_mots_gt(gt_file) + for frame_id, mask_objects in mask_objects_per_frame.items(): + for mask_object in mask_objects: + # class_id = 1 is car + # class_id = 2 is pedestrian + # class_id = 10 IGNORE + if mask_object.class_id == 1: + continue + + bbox = rletools.toBbox(mask_object.mask) + x1, y1, w, h = [int(c) for c in bbox] + bbox = np.array([x1, y1, x1 + w, y1 + h], dtype=np.float32) + + # area = bbox[2] * bbox[3] + # image_id = img_file_name_to_id[f"{seq}_{frame_id:06d}.jpg"] + + # segmentation = { + # 'size': mask_object.mask['size'], + # 'counts': mask_object.mask['counts'].decode(encoding='UTF-8')} + + boxes[frame_id][mask_object.track_id] = bbox + visibility[frame_id][mask_object.track_id] = 1.0 + + return boxes, visibility + + def write_results(self, results: dict, output_dir: str) -> None: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + result_file_path = osp.join(output_dir, f"{self._seq_name}.txt") + + with open(result_file_path, "w") as res_file: + writer = csv.writer(res_file, delimiter=' ') + for i, track in results.items(): + for frame, data in track.items(): + mask = np.asfortranarray(data['mask']) + rle_mask = rletools.encode(mask) + + writer.writerow([ + frame + 1, + i + 1, + 2, # class pedestrian + mask.shape[0], + mask.shape[1], + rle_mask['counts'].decode(encoding='UTF-8')]) + + def load_results(self, results_dir: str) -> dict: + results = {} + + if results_dir is None: + return results + + file_path = osp.join(results_dir, self.results_file_name) + + if not os.path.isfile(file_path): + return results + + mask_objects_per_frame = load_mots_gt(file_path) + + for frame_id, mask_objects in mask_objects_per_frame.items(): + for mask_object in mask_objects: + # class_id = 1 is car + # class_id = 2 is pedestrian + # class_id = 10 IGNORE + if mask_object.class_id == 1: + continue + + bbox = rletools.toBbox(mask_object.mask) + x1, y1, w, h = [int(c) for c in bbox] + bbox = np.array([x1, y1, x1 + w, y1 + h], dtype=np.float32) + + # area = bbox[2] * bbox[3] + # image_id = img_file_name_to_id[f"{seq}_{frame_id:06d}.jpg"] + + # segmentation = { + # 'size': mask_object.mask['size'], + # 'counts': mask_object.mask['counts'].decode(encoding='UTF-8')} + + track_id = mask_object.track_id - 1 + if track_id not in results: + results[track_id] = {} + + results[track_id][frame_id - 1] = {} + results[track_id][frame_id - 1]['mask'] = rletools.decode(mask_object.mask) + results[track_id][frame_id - 1]['bbox'] = bbox.tolist() + results[track_id][frame_id - 1]['score'] = 1.0 + + return results + + def __str__(self) -> str: + return self._seq_name + + +class SegmentedObject: + """ + Helper class for segmentation objects. 
+ """ + def __init__(self, mask: dict, class_id: int, track_id: int) -> None: + self.mask = mask + self.class_id = class_id + self.track_id = track_id + + +def load_mots_gt(path: str) -> dict: + """Load MOTS ground truth from path.""" + objects_per_frame = {} + track_ids_per_frame = {} # Check that no frame contains two objects with same id + combined_mask_per_frame = {} # Check that no frame contains overlapping masks + + with open(path, "r") as gt_file: + for line in gt_file: + line = line.strip() + fields = line.split(" ") + + frame = int(fields[0]) + if frame not in objects_per_frame: + objects_per_frame[frame] = [] + if frame not in track_ids_per_frame: + track_ids_per_frame[frame] = set() + if int(fields[1]) in track_ids_per_frame[frame]: + assert False, f"Multiple objects with track id {fields[1]} in frame {fields[0]}" + else: + track_ids_per_frame[frame].add(int(fields[1])) + + class_id = int(fields[2]) + if not(class_id == 1 or class_id == 2 or class_id == 10): + assert False, "Unknown object class " + fields[2] + + mask = { + 'size': [int(fields[3]), int(fields[4])], + 'counts': fields[5].encode(encoding='UTF-8')} + if frame not in combined_mask_per_frame: + combined_mask_per_frame[frame] = mask + elif rletools.area(rletools.merge([ + combined_mask_per_frame[frame], mask], + intersect=True)): + assert False, "Objects with overlapping masks in frame " + fields[0] + else: + combined_mask_per_frame[frame] = rletools.merge( + [combined_mask_per_frame[frame], mask], + intersect=False) + objects_per_frame[frame].append(SegmentedObject( + mask, + class_id, + int(fields[1]) + )) + + return objects_per_frame diff --git a/src/trackformer/datasets/transforms.py b/src/trackformer/datasets/transforms.py new file mode 100644 index 0000000..09b0727 --- /dev/null +++ b/src/trackformer/datasets/transforms.py @@ -0,0 +1,471 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Transforms and data augmentation for both image + bbox. +""" +import random +from typing import Union + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from ..util.box_ops import box_xyxy_to_cxcywh +from ..util.misc import interpolate + + +def crop(image, target, region): + i, j, h, w = region + target = target.copy() + + if isinstance(image, torch.Tensor): + cropped_image = image[:, j:j + w, i:i + h] + else: + cropped_image = F.crop(image, *region) + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd", "ignore", "track_ids"] + + orig_area = target["area"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? 
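+        # masks are stored as [N, H, W]; slicing rows i:i+h and columns j:j+w keeps them
+        # aligned with the image crop region (top=i, left=j, height=h, width=w)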
+ target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "masks" in target: + keep = target['masks'].flatten(1).any(1) + else: + # cropped_boxes = target['boxes'].reshape(-1, 2, 2) + # keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + + # new area must be at least % of orginal area + keep = target["area"] >= orig_area * 0.2 + + for field in fields: + if field in target: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + if isinstance(image, torch.Tensor): + flipped_image = image.flip(-1) + _, width, _ = image.size() + else: + flipped_image = F.hflip(image) + width, _ = image.size + + target = target.copy() + + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] \ + * torch.as_tensor([-1, 1, -1, 1]) \ + + torch.as_tensor([width, 0, width, 0]) + target["boxes"] = boxes + + if "boxes_ignore" in target: + boxes = target["boxes_ignore"] + boxes = boxes[:, [2, 1, 0, 3]] \ + * torch.as_tensor([-1, 1, -1, 1]) \ + + torch.as_tensor([width, 0, width, 0]) + target["boxes_ignore"] = boxes + + if "masks" in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes \ + * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # pad_left, pad_top, pad_right, pad_bottom + padded_image = F.pad(image, padding) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? 
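+    # padding is (left, top, right, bottom); only the left/top offsets shift the xyxy
+    # box coordinates, hence the correction below adds padding[0] and padding[1]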
+ w, h = padded_image.size + + if "boxes" in target: + # correct xyxy from left and right paddings + target["boxes"] += torch.tensor( + [padding[0], padding[1], padding[0], padding[1]]) + + target["size"] = torch.tensor([h, w]) + if "masks" in target: + # padding_left, padding_right, padding_top, padding_bottom + target['masks'] = torch.nn.functional.pad( + target['masks'], + (padding[0], padding[2], padding[1], padding[3])) + return padded_image, target + + +class RandomCrop: + def __init__(self, size): + # in hxw + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop: + def __init__(self, + min_size: Union[tuple, list, int], + max_size: Union[tuple, list, int] = None): + if isinstance(min_size, int): + min_size = (min_size, min_size) + if isinstance(max_size, int): + max_size = (max_size, max_size) + + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + if self.max_size is None: + w = random.randint(min(self.min_size[0], img.width), img.width) + h = random.randint(min(self.min_size[1], img.height), img.height) + else: + w = random.randint( + min(self.min_size[0], img.width), + min(img.width, self.max_size[0])) + h = random.randint( + min(self.min_size[1], img.height), + min(img.height, self.max_size[1])) + + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop: + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip: + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RepeatUntilMaxObjects: + def __init__(self, transforms, num_max_objects): + self._num_max_objects = num_max_objects + self._transforms = transforms + + def __call__(self, img, target): + num_objects = None + while num_objects is None or num_objects > self._num_max_objects: + img_trans, target_trans = self._transforms(img, target) + num_objects = len(target_trans['boxes']) + return img_trans, target_trans + + +class RandomResize: + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomResizeTargets: + def __init__(self, scale=0.5): + self.scalce = scale + + def __call__(self, img, target=None): + img = F.to_tensor(img) + img_c, img_w, img_h = img.shape + + rescaled_boxes = [] + rescaled_box_images = [] + for box in target['boxes']: + y1, x1, y2, x2 = box.int().tolist() + w = x2 - x1 + h = y2 - y1 + + box_img = img[:, x1:x2, y1:y2] + random_scale = random.uniform(0.5, 2.0) + scaled_width = int(random_scale * w) + scaled_height = int(random_scale * h) + + box_img = F.to_pil_image(box_img) + rescaled_box_image = F.resize( + box_img, + (scaled_width, scaled_height)) + rescaled_box_images.append(F.to_tensor(rescaled_box_image)) + rescaled_boxes.append([y1, x1, y1 + scaled_height, x1 + scaled_width]) + + for box in target['boxes']: + y1, x1, y2, x2 = 
box.int().tolist() + w = x2 - x1 + h = y2 - y1 + + erase_value = torch.empty( + [img_c, w, h], + dtype=torch.float32).normal_() + + img = F.erase( + img, x1, y1, w, h, erase_value, True) + + for box, rescaled_box_image in zip(target['boxes'], rescaled_box_images): + y1, x1, y2, x2 = box.int().tolist() + w = x2 - x1 + h = y2 - y1 + _, scaled_width, scaled_height = rescaled_box_image.shape + + rescaled_box_image = rescaled_box_image[ + :, + :scaled_width - max(x1 + scaled_width - img_w, 0), + :scaled_height - max(y1 + scaled_height - img_h, 0)] + + img[:, x1:x1 + scaled_width, y1:y1 + scaled_height] = rescaled_box_image + + target['boxes'] = torch.tensor(rescaled_boxes).float() + img = F.to_pil_image(img) + return img, target + + +class RandomPad: + def __init__(self, max_size): + if isinstance(max_size, int): + max_size = (max_size, max_size) + + self.max_size = max_size + + def __call__(self, img, target): + w, h = img.size + pad_width = max(self.max_size[0] - w, 0) # random.randint(0, max(self.max_size[0] - w, 0)) + pad_height = max(self.max_size[1] - h, 0) # random.randint(0, max(self.max_size[1] - h, 0)) + + pad_left = random.randint(0, pad_width) + pad_right = pad_width - pad_left + pad_top = random.randint(0, pad_height) + pad_bottom = pad_height - pad_top + + padding = (pad_left, pad_top, pad_right, pad_bottom) + + return pad(img, target, padding) + + +class RandomSelect: + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor: + def __call__(self, img, target=None): + return F.to_tensor(img), target + + +class RandomErasing: + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + self.eraser = T.RandomErasing() + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + def __call__(self, img, target): + if random.uniform(0, 1) < self.p: + img = F.to_tensor(img) + + x, y, h, w, v = self.eraser.get_params( + img, scale=self.scale, ratio=self.ratio, value=self.value) + + img = F.erase(img, x, y, h, w, v, self.inplace) + img = F.to_pil_image(img) + + # target + fields = ['boxes', "labels", "area", "iscrowd", "ignore", "track_ids"] + + if 'boxes' in target: + erased_box = torch.tensor([[y, x, y + w, x + h]]).float() + + lt = torch.max(erased_box[:, None, :2], target['boxes'][:, :2]) # [N,M,2] + rb = torch.min(erased_box[:, None, 2:], target['boxes'][:, 2:]) # [N,M,2] + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + keep = inter[0] <= 0.7 * target['area'] + + left = torch.logical_and( + target['boxes'][:, 0] < erased_box[:, 0], + target['boxes'][:, 2] > erased_box[:, 0]) + left = torch.logical_and(left, inter[0].bool()) + + right = torch.logical_and( + target['boxes'][:, 0] < erased_box[:, 2], + target['boxes'][:, 2] > erased_box[:, 2]) + right = torch.logical_and(right, inter[0].bool()) + + top = torch.logical_and( + target['boxes'][:, 1] < erased_box[:, 1], + target['boxes'][:, 3] > erased_box[:, 1]) + top = torch.logical_and(top, inter[0].bool()) + + bottom = torch.logical_and( + target['boxes'][:, 1] < erased_box[:, 3], + target['boxes'][:, 3] > erased_box[:, 3]) + bottom = 
torch.logical_and(bottom, inter[0].bool()) + + only_one_crop = (top.float() + bottom.float() + left.float() + right.float()) > 1 + left[only_one_crop] = False + right[only_one_crop] = False + top[only_one_crop] = False + bottom[only_one_crop] = False + + target['boxes'][:, 2][left] = erased_box[:, 0] + target['boxes'][:, 0][right] = erased_box[:, 2] + target['boxes'][:, 3][top] = erased_box[:, 1] + target['boxes'][:, 1][bottom] = erased_box[:, 3] + + for field in fields: + if field in target: + target[field] = target[field][keep] + + return img, target + + +class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target=None): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string diff --git a/src/trackformer/datasets/vidhoi.py b/src/trackformer/datasets/vidhoi.py new file mode 100644 index 0000000..babf679 --- /dev/null +++ b/src/trackformer/datasets/vidhoi.py @@ -0,0 +1,168 @@ +import json +import math +import numpy as np +from .coco import CocoDetection, make_coco_transforms +from pathlib import Path +import random +import copy +import torch +from . 
import transforms as T + +class VidHOI(CocoDetection): + + def __init__(self, img_folder, ann_file, transforms, return_masks, + prev_frame=False, prev_frame_rnd_augs=0.0, norm_transform=None, clip_length=None, return_box_instant_trajectories=False): + super(VidHOI, self).__init__( + img_folder, ann_file, transforms, return_masks, False, + norm_transform, prev_frame, prev_frame_rnd_augs, clip_length=clip_length, dataset_name='vidhoi') + + self.return_box_instant_trajectories = return_box_instant_trajectories + + def _add_frame_to_target(self, image_id, random_state): + random.setstate(random_state) + frame_img, frame_target = self._getitem_from_id(image_id) + frame_img, frame_target = self._norm_transforms(frame_img, frame_target) + + if self.return_box_instant_trajectories: + frame_target = self._get_box_instant_trajectories(image_id, frame_target) + + return frame_img, frame_target + + def sequence_infos(self): + seqs = self.coco.dataset['sequences'] + startend_image_ids = self.coco.dataset['sequence_startend_image_ids'] + startend_idx = [(self.ids.index(se[0]), self.ids.index(se[1])) for se in startend_image_ids] + return seqs, startend_idx + + def _get_box_instant_trajectories(self, idx, target): + org_img_id = self.ids[idx] + org_img_info = self.coco.imgs[org_img_id] + vfolder, vkey = org_img_info['video_key'].split('/') + if 'train' in self.root: + vidor_anntation_file = f"data/VidHOI/VidOR/annotations/train/{vfolder}/{vkey}.json" + elif 'validation' in self.root: + vidor_anntation_file = f"data/VidHOI/VidOR/annotations/validation/{vfolder}/{vkey}.json" + with open(vidor_anntation_file, 'r') as f: + org_annotations = json.load(f) + + orig_h, orig_w = target['orig_size'] + window_size = 12 + trajectories = torch.zeros(len(target['track_ids']), window_size*2*4) + for fid, fboxes in enumerate(org_annotations['trajectories'][int(org_img_info['frame_key'])-window_size:int(org_img_info['frame_key'])+window_size]): + tid2info = {x['tid']: x for x in fboxes} + for obj_id, tid in enumerate(target['track_ids']): + if tid.item() in tid2info: + bbox = tid2info[tid.item()]['bbox'] + trajectories[obj_id, fid*4:4*(fid+1)] = torch.tensor([bbox['xmin']/orig_w, bbox['ymin']/orig_h, bbox['xmax']/orig_w, bbox['ymax']/orig_h]) + + target['box_instant_trajectories'] = trajectories + return target + + def __getitem__(self, idx): + random_state = random.getstate() + + if self.clip_mode: + while(True): # get valid video clip + org_img_id = self.ids[idx] + org_img_info = self.coco.imgs[org_img_id] + if org_img_info['frame_id'] < self.clip_length-1: + idx += self.clip_length + else: + break + + img, target = self._getitem_from_id(idx) + img, target = self._norm_transforms(img, target) + + if self.return_box_instant_trajectories: + target = self._get_box_instant_trajectories(idx, target) + + if self.clip_mode: + start_id = self.ids.index(org_img_info['first_frame_image_id']) + prev_image_ids = np.sort((org_img_info['frame_id'] - np.arange(0, org_img_info['frame_id']+1))[1:self.clip_length]) + start_id + + prev_frame_imgs, prev_frame_targets = [], [] + for prev_image_id in prev_image_ids: + frame_img, frame_target = self._add_frame_to_target(prev_image_id, random_state) + prev_frame_imgs.append(frame_img) + prev_frame_targets.append(frame_target) + + # compose clip + append_num = self.clip_length - len(prev_frame_imgs) + img = prev_frame_imgs + [img.clone() for _ in range(append_num)] + target = prev_frame_targets + [copy.deepcopy(target) for _ in range(append_num)] + + return img, target + + +class 
VidHOIVideo(VidHOI): + + def __init__(self, img_folder, ann_file, transforms, return_masks=False, norm_transform=None, clip_length=3): + super(VidHOIVideo, self).__init__(img_folder, ann_file, transforms, return_masks, norm_transform=norm_transform) + + self.clip_length = clip_length + self.videos, self.startend_idx = self.sequence_infos() + + def sequence_infos(self): + seqs = self.coco.dataset['sequences'] + startend_image_ids = self.coco.dataset['sequence_startend_image_ids'] + + seq_clips, seq_startend_idx = [], [] + for vid, se in zip(seqs, startend_image_ids): + vid_length = se[1] - se[0] + 1 + if vid_length < 2: continue + seg_num = math.ceil(vid_length / self.clip_length) + for seg_id in range(seg_num): + if seg_id == seg_num-1: + seq_startend_idx.append((max(se[0], se[1]-self.clip_length+1), se[1])) + else: + seq_startend_idx.append((se[0]+self.clip_length*seg_id, se[0]+self.clip_length*(seg_id+1)-1)) + seq_clips.append(f"{vid}_{seg_id}") + + return seq_clips, seq_startend_idx + + def __getitem__(self, idx): + random_state = random.getstate() + + frames, targets = [], [] + start, end = self.startend_idx[idx] + for fid in range(start, end+1): + img, target = self._add_frame_to_target(fid, random_state) + frames.append(img) + targets.append(target) + + return frames, targets + + def __len__(self): + return len(self.videos) + +def build_vidhoi(image_set, args): + root = Path(args.vidhoi_path) + assert root.exists(), f'provided VidHOI path {root} does not exist' + + split = getattr(args, f"{image_set}_split") + + img_folder = f"{root}/frames/train" if image_set == 'train' else f"{root}/frames/validation" + ann_file = root / f"VidHOI_annotations/{split}_cocofmt.json" + + if args.object_detector == 'frcnn': + transforms = None + norm_transforms = T.Compose([T.ToTensor()]) + else: + transforms, norm_transforms = make_coco_transforms( + image_set, args.img_transform, no_crop=True) + + if args.sgg_use_STTran: + dataset = VidHOIVideo(img_folder, ann_file, transforms=transforms, norm_transform=norm_transforms, clip_length=args.clip_length) + else: + dataset = VidHOI( + img_folder, ann_file, + transforms=transforms, + norm_transform=norm_transforms, + return_masks=args.masks, + prev_frame=args.tracking, + prev_frame_rnd_augs=args.track_prev_frame_rnd_augs, + clip_length=args.clip_length, + return_box_instant_trajectories=(args.hoi_oracle_mode and args.hoi_oracle_mode_use_instant_trajectory) + ) + + return dataset diff --git a/src/trackformer/datasets/vidhoi_eval.py b/src/trackformer/datasets/vidhoi_eval.py new file mode 100644 index 0000000..ab1231d --- /dev/null +++ b/src/trackformer/datasets/vidhoi_eval.py @@ -0,0 +1,396 @@ +import copy +import itertools +import numpy as np +import torch +from collections import defaultdict +import json +import matplotlib.pyplot as plt +from ..util import box_ops +from ..util import misc as utils + +# from ST-HOI paper/code +TEMPORAL_predicates = ['towards', 'away', 'pull', 'caress', 'push', 'press', 'wave', 'hit', 'lift', 'pat', 'grab', 'chase', 'release', 'wave_hand_to', 'squeeze', 'kick', 'shout_at', 'throw', 'smell', 'knock', 'lick', 'open', 'close', 'get_on', 'get_off'] + +def argsort_desc(scores): + """ + Returns the indices that sort scores descending in a smart way + :param scores: Numpy array of arbitrary size + :return: an array of size [numel(scores), dim(scores)] where each row is the index you'd + need to get the score. 
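+    Example: for a 2-D `scores` array, the i-th row of the result is the
+    (row, column) index of the i-th largest value.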
+ """ + return np.column_stack(np.unravel_index(np.argsort(-scores.ravel()), scores.shape)) + +class VidHOIEvaluator(): + def __init__(self, args): + self.overlap_iou = 0.5 + self.max_hois = 100 + + # meta_infos + train_annotation_file = f"{args.vidhoi_path}/VidHOI_annotations/{args.train_split}_cocofmt.json" + with open(train_annotation_file, 'r') as f: + train_annotations = json.load(f) + self.object_categories = [x['name'] for x in train_annotations['categories']][1:] # remove __background__ + self.predicates = [x['name'] for x in train_annotations['predicate_categories']] + + triplet_counts = np.zeros((len(self.predicates), len(self.object_categories))) + for video_key, frame_dict in train_annotations['relation_annotations'].items(): + for frame_key, rels in frame_dict.items(): + for rel in rels: + triplet_counts[rel['predicate'], rel['object_class']] += 1 + self.correct_mat = (triplet_counts>0).astype("float64") + + # initialize + self.sum_gts = {} + self.gt_triplets = [] + self.preds = [] + self.gts = [] + + self.fp = defaultdict(list) + self.tp = defaultdict(list) + self.score = defaultdict(list) + + def sttran_update(self, gts, outputs, box_preds): + pred_top_rel_pairs = [] + for idx, frame_gt in enumerate(gts): + frame_box_pred = box_preds[frame_gt['image_id'].item()] + + # relation predictions + pred_boxes = frame_box_pred['boxes'].cpu().numpy() + pred_classes = frame_box_pred['labels'].cpu().numpy() + + rel_pairs = outputs['pred_rel_pairs'][idx].cpu().numpy() # scores + predicate_scores = outputs['pred_relations'][idx].sigmoid().cpu() + triplet_scores = predicate_scores * frame_box_pred['scores'][rel_pairs[:,0]].unsqueeze(1) * frame_box_pred['scores'][rel_pairs[:,1]].unsqueeze(1) + + score_mask = self.correct_mat[:, pred_classes[rel_pairs[:,1]]].transpose() # mask unseen triplets + triplet_scores = triplet_scores.numpy() * score_mask + + score_inds = argsort_desc(triplet_scores)[:self.max_hois] # get top100 + pred_rels = np.column_stack([rel_pairs[score_inds[:, 0]], score_inds[:, 1]]) # + rel_scores = triplet_scores[score_inds[:, 0], score_inds[:, 1]] + pred_top_rel_pairs.append(pred_rels) + + # gt + gt_boxes = frame_gt['boxes'].cpu().numpy() + gt_classes = frame_gt['labels'].cpu().numpy()-1 + gt_relations = frame_gt['relation_map'].nonzero().cpu().numpy() + if len(gt_relations) == 0: continue # skip frames with no gt relations, follow ST-HOI + self.gts.append({ + 'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(gt_boxes, gt_classes)], + 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in gt_relations] + }) + for hoi in self.gts[-1]['hoi_annotation']: + triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'], + self.gts[-1]['annotations'][hoi['object_id']]['category_id'], + hoi['category_id']) + if triplet not in self.gt_triplets: self.gt_triplets.append(triplet) + if triplet not in self.sum_gts: self.sum_gts[triplet] = 0 + self.sum_gts[triplet] += 1 + + self.preds.append({ + 'predictions': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(pred_boxes, pred_classes)], + 'hoi_prediction': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2], 'score': score} + for hoi, score in zip(pred_rels, rel_scores)] + }) + + return pred_top_rel_pairs + + def update(self, gts, outputs, box_preds): + pred_top_rel_pairs = [] + for idx, frame_gt in enumerate(gts): + frame_box_pred = box_preds[frame_gt['image_id'].item()] + img_h, img_w = frame_gt['orig_size'] + # relation predictions 
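+            # Descriptive note (added): a triplet score is the predicate score times the
+            # subject and object box scores; `correct_mat` zeroes out (predicate, object-class)
+            # combinations never seen in training before the top `max_hois` triplets are kept.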
+ pred_boxes = frame_box_pred['boxes'].cpu().numpy() + pred_classes = frame_box_pred['labels'].cpu().numpy() + + rel_pairs = outputs['pred_rel_pairs'][idx].cpu().numpy() # scores + predicate_scores = outputs['pred_relations'][idx].sigmoid() + triplet_scores = predicate_scores * frame_box_pred['scores'][rel_pairs[:,0]].unsqueeze(1) * frame_box_pred['scores'][rel_pairs[:,1]].unsqueeze(1) + # triplet_scores = predicate_scores * outputs['pred_relation_exists'][idx].sigmoid().unsqueeze(-1) + + score_mask = self.correct_mat[:, pred_classes[rel_pairs[:,1]]].transpose() # mask unseen triplets + triplet_scores = triplet_scores.cpu().numpy() * score_mask + + score_inds = argsort_desc(triplet_scores)[:self.max_hois] # get top100 + pred_rels = np.column_stack([rel_pairs[score_inds[:, 0]], score_inds[:, 1]]) # + rel_scores = triplet_scores[score_inds[:, 0], score_inds[:, 1]] + pred_top_rel_pairs.append(pred_rels) + + # gt + boxes = box_ops.box_cxcywh_to_xyxy(frame_gt['boxes']) + scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).to(boxes.device) + gt_boxes = (boxes * scale_fct).cpu().numpy() + + gt_classes = frame_gt['labels'].cpu().numpy() + gt_relations = frame_gt['relation_map'].nonzero().cpu().numpy() + if len(gt_relations) == 0: # skip frames with no gt relations, follow ST-HOI + continue + self.gts.append({ + 'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(gt_boxes, gt_classes)], + 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in gt_relations] + }) + for hoi in self.gts[-1]['hoi_annotation']: + triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'], + self.gts[-1]['annotations'][hoi['object_id']]['category_id'], + hoi['category_id']) + if triplet not in self.gt_triplets: self.gt_triplets.append(triplet) + if triplet not in self.sum_gts: self.sum_gts[triplet] = 0 + self.sum_gts[triplet] += 1 + + self.preds.append({ + 'predictions': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(pred_boxes, pred_classes)], + 'hoi_prediction': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2], 'score': score} + for hoi, score in zip(pred_rels, rel_scores)] + }) + # self.preds.append({ + # 'predictions': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(gt_boxes, gt_classes)], + # 'hoi_prediction': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2], 'score': 1} for hoi in gt_relations] + # }) # test evaluate with GT + + return pred_top_rel_pairs + + def synchronize_between_processes(self): + self.gts = list(itertools.chain(*utils.all_gather(self.gts))) + self.preds = list(itertools.chain(*utils.all_gather(self.preds))) + self.gt_triplets = list(set(itertools.chain(*utils.all_gather(self.gt_triplets)))) + assert len(self.gts) == len(self.preds) + + all_sum_gts = utils.all_gather(self.sum_gts) + merged_sum_gts = all_sum_gts[0].copy() + for single_gts in all_sum_gts[1:]: + for triplet, count in single_gts.items(): + if triplet in merged_sum_gts: + merged_sum_gts[triplet] += count + else: + merged_sum_gts[triplet] = count + self.sum_gts = merged_sum_gts + + def evaluate(self): + for img_id, (img_preds, img_gts) in enumerate(zip(self.preds, self.gts)): + print(f"Evaluating Score Matrix... 
: [{(img_id+1):>4}/{len(self.gts):<4}]", flush=True, end="\r") + pred_bboxes = img_preds['predictions'] + gt_bboxes = img_gts['annotations'] + pred_hois = img_preds['hoi_prediction'] + gt_hois = img_gts['hoi_annotation'] + if len(gt_bboxes) != 0: + if len(pred_bboxes) == 0: continue + bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) + self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps) + else: + for pred_hoi in pred_hois: + triplet = [pred_bboxes[pred_hoi['subject_id']]['category_id'], + pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id']] + if triplet not in self.gt_triplets: + continue + self.tp[triplet].append(0) + self.fp[triplet].append(1) + self.score[triplet].append(pred_hoi['score']) + print(f"[stats] Score Matrix Generation completed!! ") + map = self.compute_map() + return map + + # refer to: https://github.com/coldmanck/VidHOI/blob/master/vidor_eval.ipynb + def set_rare_nonrare_triplets(self, count_threshold=25): + rare_triplets, nonrare_triplets = [], [] + for triplet, count in self.sum_gts.items(): + if count < count_threshold: + rare_triplets.append(triplet) + else: + nonrare_triplets.append(triplet) + return rare_triplets, nonrare_triplets + + def compute_map(self): + ap = defaultdict(lambda: 0) + max_recall = defaultdict(lambda: 0) + temporal_predicate_inds = [self.predicates.index(p) for p in TEMPORAL_predicates] + temporal_ap = defaultdict(lambda: 0) + non_temporal_ap = defaultdict(lambda: 0) + per_predicate_stats = {} + + # rare & nonrare eval + rare_triplets, nonrare_triplets = self.set_rare_nonrare_triplets() + rare_ap = defaultdict(lambda: 0) + nonrare_ap = defaultdict(lambda: 0) + + for triplet in self.gt_triplets: + sum_gts = self.sum_gts[triplet] + if sum_gts == 0: + continue + + tp = np.array((self.tp[triplet])) + fp = np.array((self.fp[triplet])) + if len(tp) == 0: + # ST-HOI just skip these triplets, it's a bug!! 
(https://github.com/coldmanck/VidHOI/blob/master/vidor_eval.ipynb) + ap[triplet] = 0 + max_recall[triplet] = 0 + if triplet[-1] in temporal_predicate_inds: + temporal_ap[triplet] = 0 + else: + non_temporal_ap[triplet] = 0 + + if triplet in rare_triplets: + rare_ap[triplet] = 0 + elif triplet in nonrare_triplets: + nonrare_ap[triplet] = 0 + continue + + score = np.array(self.score[triplet]) + sort_inds = np.argsort(-score) + fp = fp[sort_inds] + tp = tp[sort_inds] + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / sum_gts + prec = tp / (fp + tp) + ap[triplet] = self.voc_ap(rec, prec) + max_recall[triplet] = np.amax(rec) + if triplet[-1] in temporal_predicate_inds: + temporal_ap[triplet] = ap[triplet] + else: + non_temporal_ap[triplet] = ap[triplet] + + if triplet in rare_triplets: + rare_ap[triplet] = ap[triplet] + elif triplet in nonrare_triplets: + nonrare_ap[triplet] = ap[triplet] + + # per predicate stats + predicate_name = self.predicates[triplet[-1]] + if predicate_name not in per_predicate_stats: + per_predicate_stats[predicate_name] = { + 'is_temporal': predicate_name in TEMPORAL_predicates, + 'triplets': [], + 'triplet_aps': [], + 'triplets_gt_counts': [] + } + + per_predicate_stats[predicate_name]['triplets'].append(triplet) + per_predicate_stats[predicate_name]['triplet_aps'].append(ap[triplet]) + per_predicate_stats[predicate_name]['triplets_gt_counts'].append(self.sum_gts[triplet]) + + m_ap = np.mean(list(ap.values())) * 100 # percentage + m_max_recall = np.mean(list(max_recall.values())) * 100 + temporal_m_ap = np.mean(list(temporal_ap.values())) * 100 + non_temporal_m_ap = np.mean(list(non_temporal_ap.values())) * 100 + # plt.hist(list(ap.values()), bins=20); plt.show() + + m_rare_ap = np.mean(list(rare_ap.values())) * 100 + m_nonrare_ap = np.mean(list(nonrare_ap.values())) * 100 + + print(f'======================#total triplets={len(self.gt_triplets)} (Temporal/Spatial={len(temporal_ap)}/{len(non_temporal_ap)}, Rare/Non-rare={len(rare_ap)}/{len(nonrare_ap)}), #frames={len(self.gts)}======================') + print(f'mAP (Full / Temporal / Spatial): {m_ap:.2f} / {temporal_m_ap:.2f} / {non_temporal_m_ap:.2f} || {m_rare_ap:.2f} / {m_nonrare_ap:.2f}, mR (Full): {m_max_recall:.2f}') + + ## per-predicate evaluation results + print(f'======================Per-predicate======================') + print(f"name,\tis_temporal,\tgt_counts,\tp-mAP,\tp-wmAP") + for predicate, stat in per_predicate_stats.items(): + predicate_mAP = sum(stat['triplet_aps']) / len(stat['triplet_aps']) + predicate_wmAP = sum(np.array(stat['triplet_aps']) * np.array(stat['triplets_gt_counts'])) / sum(stat['triplets_gt_counts']) # weighted + print(f"{predicate},\t{stat['is_temporal']},\t{sum(stat['triplets_gt_counts'])},\t{predicate_mAP * 100 :.2f},\t{predicate_wmAP * 100 :.2f}") + per_predicate_stats[predicate].update({'mAP': predicate_mAP, 'wmAP': predicate_wmAP}) + # pmAP_all = np.mean([x['mAP'] for x in per_predicate_stats.values()]) * 100 + # pmAP_temporal = np.mean([x['mAP'] for x in per_predicate_stats.values() if x['is_temporal']]) * 100 + # pmAP_spatial = np.mean([x['mAP'] for x in per_predicate_stats.values() if not x['is_temporal']]) * 100 + # print(f"\n Predicate-level mAP (Full / Temporal / Spatial): {pmAP_all:.2f} / {pmAP_temporal:.2f} / {pmAP_spatial:.2f}") + + return {'mAP': m_ap, 'mean max recall': m_max_recall} + + def voc_ap(self, rec, prec): + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
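+        # `ap` is the 11-point interpolated average precision (PASCAL VOC 2007 style):
+        # the mean of the best attainable precision at recall thresholds 0.0, 0.1, ..., 1.0.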
+        return ap
+
+    def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps):
+        # accumulate per-triplet TP/FP flags and confidence scores, greedily matching
+        # each predicted HOI (sorted by score) to a not-yet-matched ground-truth HOI
+        pos_pred_ids = match_pairs.keys()
+        vis_tag = np.zeros(len(gt_hois))
+        pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
+        if len(pred_hois) != 0:
+            for pred_hoi in pred_hois:
+                is_match = 0
+                if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi['object_id'] in pos_pred_ids:
+                    pred_sub_ids = match_pairs[pred_hoi['subject_id']]
+                    pred_obj_ids = match_pairs[pred_hoi['object_id']]
+                    pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']]
+                    pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']]
+                    pred_category_id = pred_hoi['category_id']
+                    max_overlap = 0
+                    max_gt_hoi = 0
+                    for gt_hoi in gt_hois:
+                        if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and pred_category_id == gt_hoi['category_id']:
+                            is_match = 1
+                            min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])],
+                                                 pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])])
+                            if min_overlap_gt > max_overlap:
+                                max_overlap = min_overlap_gt
+                                max_gt_hoi = gt_hoi
+                triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'], pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id'])
+                if triplet not in self.gt_triplets:
+                    continue
+                if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0:
+                    self.fp[triplet].append(0)
+                    self.tp[triplet].append(1)
+                    vis_tag[gt_hois.index(max_gt_hoi)] = 1
+                else:
+                    self.fp[triplet].append(1)
+                    self.tp[triplet].append(0)
+                self.score[triplet].append(pred_hoi['score'])
+
+    def compute_iou_mat(self, bbox_list1, bbox_list2):
+        iou_mat = np.zeros((len(bbox_list1), len(bbox_list2)))
+        if len(bbox_list1) == 0 or len(bbox_list2) == 0:
+            return {}, {}
+        for i, bbox1 in enumerate(bbox_list1):
+            for j, bbox2 in enumerate(bbox_list2):
+                iou_i = self.compute_IOU(bbox1, bbox2)
+                iou_mat[i, j] = iou_i
+
+        iou_mat_ov = iou_mat.copy()
+        iou_mat[iou_mat >= self.overlap_iou] = 1
+        iou_mat[iou_mat < self.overlap_iou] = 0
+
+        match_pairs = np.nonzero(iou_mat)
+        match_pairs_dict = {}
+        match_pair_overlaps = {}
+        if iou_mat.max() > 0:
+            for i, pred_id in enumerate(match_pairs[1]):
+                if pred_id not in match_pairs_dict.keys():
+                    match_pairs_dict[pred_id] = []
+                    match_pair_overlaps[pred_id] = []
+                match_pairs_dict[pred_id].append(match_pairs[0][i])
+                match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i], pred_id])
+        return match_pairs_dict, match_pair_overlaps
+
+    def compute_IOU(self, bbox1, bbox2):
+        if isinstance(bbox1['category_id'], str):
+            bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
+        if isinstance(bbox2['category_id'], str):
+            bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
+        if bbox1['category_id'] == bbox2['category_id']:
+            rec1 = bbox1['bbox']
+            rec2 = bbox2['bbox']
+            # compute the area of each rectangle
+            S_rec1 = (rec1[2] - rec1[0] + 1) * (rec1[3] - rec1[1] + 1)
+            S_rec2 = (rec2[2] - rec2[0] + 1) * (rec2[3] - rec2[1] + 1)
+
+            # compute the total area of the two rectangles
+            sum_area = S_rec1 + S_rec2
+
+            # find the edges of the intersection rectangle
+            left_line = max(rec1[1], rec2[1])
+            right_line = min(rec1[3], rec2[3])
+            top_line = max(rec1[0], rec2[0])
+            bottom_line = min(rec1[2], rec2[2])
+            # check whether the boxes intersect at all
+            if left_line >= right_line or top_line >= bottom_line:
+                return 0
+            else:
+                intersect = (right_line - left_line + 1) * (bottom_line - top_line + 1)
+                return intersect / (sum_area - intersect)
+        else:
+            return 0
diff --git a/src/trackformer/engine.py b/src/trackformer/engine.py
new file mode 100644
index 0000000..929716b
--- /dev/null
+++ b/src/trackformer/engine.py
@@ -0,0 +1,613 @@
+# Copyright (c) 
Facebook, Inc. and its affiliates. All Rights Reserved +""" +Train and eval functions used in main.py +""" +import logging +import math +import os +import sys +from typing import Iterable + +import torch +from track import ex +import copy +from torch.utils.data import DataLoader +from tqdm import tqdm +import motmetrics as mm +import numpy as np +import time + +from trackformer.datasets import get_coco_api_from_dataset +from trackformer.datasets.coco_eval import CocoEvaluator +from trackformer.datasets.panoptic_eval import PanopticEvaluator +from trackformer.datasets.actiongenome_eval import BasicSceneGraphEvaluator +from trackformer.datasets.vidhoi_eval import VidHOIEvaluator +from trackformer.models.detr_segmentation import DETRSegm +from trackformer.util import misc as utils +from trackformer.util.box_ops import box_iou, box_cxcywh_to_xywh, box_xyxy_to_xywh +from trackformer.util.track_utils import evaluate_mot_accums +from trackformer.vis import vis_results +from trackformer.util.plot_utils import check_prediction, check_annotation +from base_trackers import CentroidTracker, SORT, IOUTracker, BYTETracker + +def make_results(outputs, targets, postprocessors, tracking, return_only_orig=True): + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + results = None + if not return_only_orig: + results = postprocessors['bbox'](outputs, target_sizes) + results_orig = postprocessors['bbox'](outputs, orig_target_sizes) + + # # targets as predictions + # results_orig = [ + # { + # 'scores': torch.ones_like(targets[0]['labels']), + # 'scores_no_object': torch.zeros_like(targets[0]['labels']), + # 'labels': targets[0]['labels'], + # 'boxes': postprocessors['bbox'].process_boxes(targets[0]['boxes'], orig_target_sizes)[0], + # } + # ] + + if 'segm' in postprocessors: + results_orig = postprocessors['segm']( + results_orig, outputs, orig_target_sizes, target_sizes) + if not return_only_orig: + results = postprocessors['segm']( + results, outputs, target_sizes, target_sizes) + + if results is None: + return results_orig, results + + for i, result in enumerate(results): + target = targets[i] + target_size = target_sizes[i].unsqueeze(dim=0) + + result['target'] = {} + result['boxes'] = result['boxes'].cpu() + + # revert boxes for visualization + for key in ['boxes', 'track_query_boxes']: + if key in target: + target[key] = postprocessors['bbox'].process_boxes( + target[key], target_size)[0].cpu() + + if 'random_boxes' in target: + random_target_sizes = torch.stack([t["random_size"] for t in targets], dim=0) + target['random_boxes'] = postprocessors['bbox'].process_boxes( + target['random_boxes'], random_target_sizes[i].unsqueeze(dim=0))[0].cpu() + + if tracking and 'prev_boxes' in target: + prev_target_sizes = torch.stack([t["prev_size"] for t in targets], dim=0) + target['prev_boxes'] = postprocessors['bbox'].process_boxes( + target['prev_boxes'], prev_target_sizes[i].unsqueeze(dim=0))[0].cpu() + + if len(target['track_query_match_ids']): + track_queries_iou, _ = box_iou( + target['boxes'][target['track_query_match_ids']], + result['boxes']) + track_queries_match_mask = target['track_queries_match_mask'] + + box_ids = [box_id for box_id, mask_value in enumerate(track_queries_match_mask == 1) + if mask_value] + + result['track_queries_with_id_iou'] = torch.diagonal(track_queries_iou[:, box_ids]) + + return results_orig, results + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 
postprocessors, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, visualizers: dict, args): + + vis_iter_metrics = None + if visualizers: + vis_iter_metrics = visualizers['iter_metrics'] + + model.train() + criterion.train() + metric_logger = utils.MetricLogger( + args.vis_and_log_interval, + delimiter=" ", + vis=vis_iter_metrics, + debug=args.debug) + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + + for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, epoch)): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + # in order to be able to modify targets inside the forward call we need + # to pass it through as torch.nn.parallel.DistributedDataParallel only + # passes copies + + outputs, targets, *_ = model(samples, targets) + + if isinstance(outputs, list): # per-frame loss computation + loss_dict_list = [criterion(o, [t]) for o, t in zip(outputs, targets)] + frame_num = len(loss_dict_list) + loss_dict = {k: sum([ld[k] for ld in loss_dict_list])/frame_num for k in loss_dict_list[0]} + else: + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = { + f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} + loss_dict_reduced_scaled = { + k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print(f"Loss is {loss_value}, stopping training") + print(loss_dict_reduced) + print(targets) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + if args.clip_max_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_max_norm) + optimizer.step() + + metric_logger.update(loss=loss_value, + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + metric_logger.update(lr=optimizer.param_groups[0]["lr"], + lr_backbone=optimizer.param_groups[1]["lr"]) + + if visualizers and (i == 0 or not i % args.vis_and_log_interval): + _, results = make_results( + outputs, targets, postprocessors, args.tracking, return_only_orig=False) + + vis_results( + visualizers['example_results'], + samples.tensors[0], + results[0], + targets[0], + args.tracking) + + # print('visualizing') + # for j in range(len(targets)): check_annotation(samples, targets, idx=j) + # for j in range(len(targets)): check_prediction(samples, outputs, targets=targets, idx=j) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, criterion, postprocessors, data_loader, device, + output_dir: str, visualizers: dict, args, epoch: int = None): + model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger( + args.vis_and_log_interval, + delimiter=" ", + debug=args.debug) + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + + base_ds = 
get_coco_api_from_dataset(data_loader.dataset) + iou_types = tuple(k for k in ('bbox', 'segm') if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + if args.hoi_detection: + actiongenome_hoi_evaluator = BasicSceneGraphEvaluator(mode='sgdet', constraint=False) + vidhoi_evaluator = VidHOIEvaluator(args) if args.dataset == 'vidhoi' else None + + panoptic_evaluator = None + if 'panoptic' in postprocessors.keys(): + panoptic_evaluator = PanopticEvaluator( + data_loader.dataset.ann_file, + data_loader.dataset.ann_folder, + output_dir=os.path.join(output_dir, "panoptic_eval"), + ) + + for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 'Test:')): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs, targets, *_ = model(samples, targets) + + # print('check results') + # for j in range(len(targets)): check_annotation(samples, targets, idx=j) + # for j in range(len(targets)): check_prediction(samples, outputs, targets=targets, idx=j) + + if isinstance(outputs, list): # per-frame loss computation + loss_dict_list = [criterion(o, [t]) for o, t in zip(outputs, targets)] + frame_num = len(loss_dict_list) + loss_dict = {k: sum([ld[k] for ld in loss_dict_list])/frame_num for k in loss_dict_list[0]} + else: + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + + if visualizers and (i == 0 or not i % args.vis_and_log_interval): + results_orig, results = make_results( + outputs, targets, postprocessors, args.tracking, return_only_orig=False) + + vis_results( + visualizers['example_results'], + samples.tensors[0], + results[0], + targets[0], + args.tracking) + else: + if isinstance(outputs, list): + targets = [targets[-1]] # only evaluate the last frame detection performance + results_orig, _ = make_results(outputs[-1], targets, postprocessors, args.tracking) + else: + results_orig, _ = make_results(outputs, targets, postprocessors, args.tracking) + + # TODO. 
remove cocoDts from coco eval and change example results output + if coco_evaluator is not None: + results_orig = { + target['image_id'].item(): output + for target, output in zip(targets, results_orig)} + # for target, output in zip([targets[-1]], [results_orig[-1]])} + coco_evaluator.update(copy.deepcopy(results_orig)) + + if panoptic_evaluator is not None: + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) + for j, target in enumerate(targets): + image_id = target["image_id"].item() + file_name = f"{image_id:012d}.png" + res_pano[j]["image_id"] = image_id + res_pano[j]["file_name"] = file_name + + panoptic_evaluator.update(res_pano) + + if args.hoi_detection: + actiongenome_hoi_evaluator.evaluate_scene_graph(targets, outputs, box_preds=results_orig) + if vidhoi_evaluator is not None: + vidhoi_evaluator.update(targets, outputs, box_preds=results_orig) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if args.hoi_detection: + actiongenome_hoi_evaluator.synchronize_between_processes() + actiongenome_hoi_evaluator.print_stats() + if vidhoi_evaluator is not None: + vidhoi_evaluator.synchronize_between_processes() + vidhoi_evaluator.evaluate() + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + print(f"#images={len(coco_evaluator.coco_eval['bbox'].params.imgIds)}") + panoptic_res = None + if panoptic_evaluator is not None: + panoptic_res = panoptic_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if 'bbox' in coco_evaluator.coco_eval: + stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + if 'segm' in coco_evaluator.coco_eval: + stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() + if panoptic_res is not None: + stats['PQ_all'] = panoptic_res["All"] + stats['PQ_th'] = panoptic_res["Things"] + stats['PQ_st'] = panoptic_res["Stuff"] + + # stats = {} + # TRACK EVAL + if args.tracking and args.tracking_eval: + stats['track_bbox'] = [] + + ex.logger = logging.getLogger("submitit") + + # distribute evaluation of seqs to processes + seqs = data_loader.dataset.sequences + seqs_per_rank = {i: [] for i in range(utils.get_world_size())} + for i, seq in enumerate(seqs): + rank = i % utils.get_world_size() + seqs_per_rank[rank].append(seq) + + # only evaluarte one seq in debug mode + if args.debug: + seqs_per_rank = {k: v[:1] for k, v in seqs_per_rank.items()} + seqs = [s for ss in seqs_per_rank.values() for s in ss] + + dataset_name = seqs_per_rank[utils.get_rank()] + if not dataset_name: + dataset_name = seqs_per_rank[0] + + model_without_ddp = model + if args.distributed: + model_without_ddp = model.module + + # mask prediction is too slow and consumes a lot of memory to + # run it during tracking training. 
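+        # Hence, for tracking evaluation, unwrap the segmentation head and run the
+        # underlying DETR model only.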
+ if isinstance(model, DETRSegm): + model_without_ddp = model_without_ddp.detr + + obj_detector_model = { + 'model': model_without_ddp, + 'post': postprocessors, + 'img_transform': args.img_transform} + + run = ex.run(config_updates={ + 'seed': None, + 'dataset_name': dataset_name, + 'frame_range': data_loader.dataset.frame_range, + 'obj_detector_model': obj_detector_model}) + + mot_accums = utils.all_gather(run.result)[:len(seqs)] + mot_accums = [item for sublist in mot_accums for item in sublist] + + # we compute seqs results on muliple nodes but evaluate the accumulated + # results due to seqs being weighted differently (seg length) + eval_summary, eval_summary_str = evaluate_mot_accums( + mot_accums, seqs) + print(eval_summary_str) + + for metric in ['mota', 'idf1']: + eval_m = eval_summary[metric]['OVERALL'] + stats['track_bbox'].append(eval_m) + + eval_stats = stats['coco_eval_bbox'][:3] + if 'coco_eval_masks' in stats: + eval_stats.extend(stats['coco_eval_masks'][:3]) + if 'track_bbox' in stats: + eval_stats.extend(stats['track_bbox']) + + # VIS + if visualizers: + vis_epoch = visualizers['epoch_metrics'] + y_data = [stats[legend_name] for legend_name in vis_epoch.viz_opts['legend']] + vis_epoch.plot(y_data, epoch) + + visualizers['epoch_eval'].plot(eval_stats, epoch) + + if args.debug: + exit() + + return eval_stats, coco_evaluator + +@torch.no_grad() +def evaluate_video_sgg(model, dataset_val, device, args, postprocessors, query_propagation_threshold=0.5, tracking_det_threshold=0.8): + if args.distributed: + model.module.tracking() + else: + model.tracking() + print(f"Tracker: {args.sgg_postprocessing_tracker}") + + # evaluators + base_ds = get_coco_api_from_dataset(dataset_val) + coco_evaluator = CocoEvaluator(base_ds, ('bbox',)) + mot_accums = [] + if args.hoi_detection: + actiongenome_hoi_evaluator = BasicSceneGraphEvaluator(mode='sgdet', constraint=False, dataset=args.dataset) + vidhoi_evaluator = VidHOIEvaluator(args) if args.dataset == 'vidhoi' else None + + # loading val frames by video + dataset_val.clip_mode = False # !! 
will affect offline evaluation + video_names, video_startend_idxs = dataset_val.sequence_infos() + if utils.get_world_size() > 1: # for multi-GPUs + video_names = video_names[utils.get_rank()::utils.get_world_size()] + video_startend_idxs = video_startend_idxs[utils.get_rank()::utils.get_world_size()] + + start_time = time.time(); model_inference_time = 0 + # video_names, video_startend_idxs = video_names[12:14], video_startend_idxs[12:14] + for vid, (video_name, v_seg_info) in enumerate(zip(video_names, video_startend_idxs)): + # if vid < 11: continue + print(f'TRACK SEQ: {video_name} ({vid}/{len(video_names)})') + video_loader = DataLoader( + torch.utils.data.Subset(dataset_val, range(v_seg_info[0], v_seg_info[1]+1)), + collate_fn=utils.collate_fn, + num_workers=args.num_workers + ) + + # trackers + if args.sgg_postprocessing_tracker == 'IOUTracker': + tracker = IOUTracker(iou_threshold=0.1) + elif args.sgg_postprocessing_tracker == 'SORT': + tracker = SORT(iou_threshold=0.1) + elif args.sgg_postprocessing_tracker == 'BYTETracker': + tracking_det_threshold = 1e-2 + tracker = BYTETracker(iou_threshold=0.1) + else: + tracker = QuerySlotTracker(num_queries=args.num_queries) + + # actiongenome_hoi_evaluator.reset_result() + mot_accums.append(mm.MOTAccumulator(auto_id=True)) + kept_box_qids, prev_states = None, {} + for frame_id, (samples, targets) in enumerate(tqdm(video_loader, file=sys.stdout)): + assert len(targets) == 1 + frame_tic = time.time() + + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + targets[0].update(prev_states) + + if args.hoi_detection and args.hoi_oracle_mode and 'prev_track_ids' in prev_states: + # match track ids between frames + target_ind_match_matrix = prev_states['prev_track_ids'].unsqueeze(dim=1).eq(targets[0]['track_ids']) + target_ind_matching = target_ind_match_matrix.any(dim=1) + target_ind_matched_idx = target_ind_match_matrix.nonzero()[:, 1] + + # index of prev frame detection in current frame box list + targets[0]['track_query_match_ids'] = target_ind_matched_idx + tracked_qids = prev_states['prev_out_ind'][target_ind_matching] + + # match mask to next frame + track_queries_match_mask = torch.zeros(args.num_queries).float() + track_queries_match_mask[tracked_qids] = 1 # tracked in current frame + track_queries_match_mask[prev_states['prev_out_ind'][~target_ind_matching]] = -1 # disappeared in current frame + targets[0]['track_queries_match_mask'] = track_queries_match_mask.to(device) + + outputs, *_ = model(samples, targets) + + # collect frame-level evaluation + results_post, _ = make_results(outputs, targets, postprocessors, tracking=False) + results_orig = {target['image_id'].item(): output for target, output in zip(targets, results_post)} + + if args.hoi_detection and args.hoi_oracle_mode: + kept_box_mask = torch.zeros_like(results_post[0]['scores']).bool() + kept_box_mask[outputs['match_res'][0]] = True + ## transfer matching to next frame + prev_states.update({ + 'prev_track_ids': targets[0]['track_ids'][outputs['match_res'][1]], + 'prev_out_ind': outputs['match_res'][0] + }) + # print(outputs['match_res'][0], targets[0]['track_ids'][outputs['match_res'][1]]) + else: + # NMS before propagation + suppress_ids = apply_nms(results_post[0], kept_box_qids) + results_post[0]['scores'][suppress_ids] = 0 + kept_box_mask = results_post[0]['scores'] > query_propagation_threshold # box kept for propagation + kept_box_qids = kept_box_mask.nonzero()[:, 0].tolist() + + # detection evaluation + 
coco_evaluator.update(copy.deepcopy(results_orig)) + if args.hoi_detection: + top_pred_rel_pairs = actiongenome_hoi_evaluator.evaluate_scene_graph(targets, outputs, box_preds=results_orig) + if vidhoi_evaluator is not None: + top_pred_rel_pairs = vidhoi_evaluator.update(targets, outputs, box_preds=results_orig) + + # check_annotation(samples, targets, idx=0) + # check_prediction(samples, outputs, targets=targets, idx=0, threshold=0.2, + # top_pred_rel_pairs=top_pred_rel_pairs, save_fig_dir=f"{args.output_dir}/demo/{video_name}") + + # mot eval + tracking_kept_box_mask = results_post[0]['scores'] > tracking_det_threshold # output tracked boxes + if isinstance(tracker, QuerySlotTracker): + pred_boxes = box_cxcywh_to_xywh(outputs['pred_boxes'][0][tracking_kept_box_mask]).cpu().numpy() + pred_labels = results_post[0]['labels'][tracking_kept_box_mask].cpu().numpy() + pred_track_ids = tracker.update(frame_id, tracking_kept_box_mask.nonzero()[:, 0].tolist()) + else: + tracks = tracker.update(box_cxcywh_to_xywh(outputs['pred_boxes'][0][tracking_kept_box_mask]).cpu().numpy(), + results_post[0]['scores'][tracking_kept_box_mask].cpu().numpy(), + results_post[0]['labels'][tracking_kept_box_mask].cpu().numpy()) + pred_boxes, pred_track_ids, pred_labels = [], [], [] + for track in tracks: + frame_index, track_id, bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, truncation, occlusion = track + pred_boxes.append([bbox_left, bbox_top, bbox_width, bbox_height]) + pred_track_ids.append(track_id) + pred_labels.append(object_category) + pred_boxes, pred_labels = np.array(pred_boxes), np.array(pred_labels) + + gt_boxes, gt_track_ids, gt_labels = box_cxcywh_to_xywh(targets[0]['boxes']).cpu().numpy(), targets[0]['track_ids'].tolist(), targets[0]['labels'].cpu().numpy() + distance = mm.distances.iou_matrix(gt_boxes, pred_boxes, max_iou=0.5) + if len(distance) > 0: + distance = np.where(gt_labels[:, None] == pred_labels[None, :], distance, np.nan) + mot_accums[-1].update(gt_track_ids, pred_track_ids, distance) + + # temporal propagation + prev_states.update({ + 'track_query_boxes': outputs['pred_boxes'][0], + 'track_query_hs_embeds': outputs['hs_embed'][0].to(device), + 'track_token_propagation_mask': kept_box_mask.float() + }) + if args.hoi_relation_propagation_on_inference and len(top_pred_rel_pairs[0]) > 0: + prev_states.update({'prev_top_rel_pairs': torch.from_numpy(np.unique(top_pred_rel_pairs[0][:args.num_hoi_queries//2, :2], axis=0))}) + if args.hoi_detection and args.hoi_use_temporal_dynamics: + if 'temporal_dynamics_feature_bank' in prev_states: + prev_states['temporal_dynamics_feature_bank'] = torch.cat((outputs['hs_embed'].to(device), prev_states['temporal_dynamics_feature_bank']), dim=0)[:args.hoi_use_temporal_dynamics_prev_length] + prev_states['temporal_dynamics_feature_mask'] = torch.cat(((kept_box_mask.float()==0).unsqueeze(0), prev_states['temporal_dynamics_feature_mask']), dim=0)[:args.hoi_use_temporal_dynamics_prev_length] + else: + prev_states['temporal_dynamics_feature_bank'] = outputs['hs_embed'].to(device) + prev_states['temporal_dynamics_feature_mask'] = (kept_box_mask.float()==0).unsqueeze(0) + model_inference_time += time.time() - frame_tic + + total_inference_time = time.time() - start_time + # eval results + mot_accums_all, video_names_all = utils.all_gather(mot_accums), utils.all_gather(video_names) + eval_summary, eval_summary_str = evaluate_mot_accums(sum(mot_accums_all, []), sum(video_names_all, [])) + print(eval_summary_str) + 
print(f'#videos={len(sum(video_names_all, []))}') + + print(f"model_inference_time={model_inference_time}s, total_inference_time={total_inference_time}s (include data loading, processing etc.)") + if args.hoi_detection: + actiongenome_hoi_evaluator.synchronize_between_processes() + actiongenome_hoi_evaluator.print_stats() + if vidhoi_evaluator is not None: + vidhoi_evaluator.synchronize_between_processes() + vidhoi_evaluator.evaluate() + + print(f"Model: time_per_frame={model_inference_time/len(vidhoi_evaluator.gts)*1000 :.2f}ms, frame_per_second={len(vidhoi_evaluator.gts)/model_inference_time :.2f}") + print(f"Model(+dataLoading etc.): time_per_frame={total_inference_time/len(vidhoi_evaluator.gts)*1000 :.2f}ms, frame_per_second={len(vidhoi_evaluator.gts)/total_inference_time :.2f}") + + coco_evaluator.synchronize_between_processes() + coco_evaluator.accumulate() + coco_evaluator.summarize() + return + +# merge boxes (NMS) +def apply_nms(res, kept_box_qids=None, threshold=0.7): + inst_scores, inst_labels, xyxy_boxes = res['scores'].clone(), res['labels'], res['boxes'] + if kept_box_qids is not None: + inst_scores[kept_box_qids] *= 4 # we prefer to keep tracked boxes in previous frames + + box_areas = (xyxy_boxes[:, 2:] - xyxy_boxes[:, :2]).prod(-1) + box_area_sum = box_areas.unsqueeze(1) + box_areas.unsqueeze(0) + + union_boxes = torch.cat([torch.min(xyxy_boxes.unsqueeze(1)[:, :, :2], xyxy_boxes.unsqueeze(0)[:, :, :2]), + torch.max(xyxy_boxes.unsqueeze(1)[:, :, 2:], xyxy_boxes.unsqueeze(0)[:, :, 2:])], dim=-1) + union_area = (union_boxes[:,:,2:] - union_boxes[:,:,:2]).prod(-1) + iou = torch.clamp(box_area_sum - union_area, min=0) / union_area + box_match_mat = torch.logical_and(iou > threshold, inst_labels.unsqueeze(1) == inst_labels.unsqueeze(0)) + + suppress_ids = [] + for box_match in box_match_mat: + group_ids = box_match.nonzero(as_tuple=False).squeeze(1) + if len(group_ids) > 1: + max_score_inst_id = group_ids[inst_scores[group_ids].argmax()] + bg_ids = group_ids[group_ids!=max_score_inst_id] + suppress_ids.append(bg_ids) + box_match_mat[:, bg_ids] = False + if len(suppress_ids) > 0: + suppress_ids = torch.cat(suppress_ids, dim=0) + return suppress_ids + +class QuerySlotTracker: + def __init__(self, num_queries, max_lost=30): + self.num_queries = num_queries + self.max_lost = max_lost + + # states + self.query_been_activated = np.zeros(num_queries, dtype=bool) + self.query_last_tracked_frame_id = np.zeros(num_queries, dtype=int) - 1 + self.query_assigned_track_ids = np.arange(num_queries) + + def update(self, current_frame_id, detected_query_ids): + track_ids = [] + for qid in detected_query_ids: + if self.query_been_activated[qid] and (current_frame_id - self.query_last_tracked_frame_id[qid] > self.max_lost): + self.query_assigned_track_ids[qid] += self.num_queries # assign new track_id to query slot + # print(f'{current_frame_id} - {self.query_last_tracked_frame_id[qid]} ==> {self.query_assigned_track_ids[qid]}') + + self.query_been_activated[qid] = True + self.query_last_tracked_frame_id[qid] = current_frame_id + track_ids.append(self.query_assigned_track_ids[qid]) + + return track_ids diff --git a/src/trackformer/models/__init__.py b/src/trackformer/models/__init__.py new file mode 100644 index 0000000..5d5f366 --- /dev/null +++ b/src/trackformer/models/__init__.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +import torch + +from .backbone import build_backbone +from .deformable_detr import DeformableDETR, DeformablePostProcess +from .deformable_transformer import build_deforamble_transformer +from .detr import DETR, PostProcess, SetCriterion +from .detr_segmentation import (DeformableDETRSegm, DeformableDETRSegmTracking, + DETRSegm, DETRSegmTracking, + PostProcessPanoptic, PostProcessSegm) +from .detr_tracking import DeformableDETRTracking, DETRTracking +from .matcher import build_matcher +from .transformer import build_transformer +from .detr_hoi import DeformableDETRHoi +from .detr_vsgg import DeformableDETRVsgg + +def build_model(args): + if args.dataset == 'coco': + num_classes = 91 + elif args.dataset == 'coco_panoptic': + num_classes = 250 + elif args.dataset in ['coco_person', 'mot', 'mot_crowdhuman']: + num_classes = 1 + elif args.dataset in ['actiongenome', 'vidhoi']: + num_classes = args.num_classes # num exclude __background__ + else: + raise NotImplementedError + + device = torch.device(args.device) + backbone = build_backbone(args) + matcher = build_matcher(args) + + detr_kwargs = { + 'backbone': backbone, + 'num_classes': num_classes - 1 if args.focal_loss else num_classes, + 'num_queries': args.num_queries, + 'aux_loss': args.aux_loss,} + + tracking_kwargs = { + 'track_query_false_positive_prob': args.track_query_false_positive_prob, + 'track_query_false_negative_prob': args.track_query_false_negative_prob, + 'track_query_noise': args.track_query_noise, + 'matcher': matcher, + 'track_query_propagation_strategy': args.track_query_propagation_strategy, + 'tracking_token_propagation': args.tracking_token_propagation, + 'clip_mode': isinstance(args.clip_length, int), + 'token_propagation_sample_rate': args.token_propagation_sample_rate, + 'tracking_match_propagation_skip_frame': args.tracking_match_propagation_skip_frame + } + + mask_kwargs = { + 'freeze_detr': args.freeze_detr} + + if args.deformable: + transformer = build_deforamble_transformer(args) + + detr_kwargs['transformer'] = transformer + detr_kwargs['num_feature_levels'] = args.num_feature_levels + detr_kwargs['with_box_refine'] = args.with_box_refine + detr_kwargs['two_stage'] = args.two_stage + + if args.tracking: + if args.masks: + model = DeformableDETRSegmTracking(mask_kwargs, tracking_kwargs, detr_kwargs) + else: + model = DeformableDETRTracking(tracking_kwargs, detr_kwargs) + elif args.video_sgg_train: + model = DeformableDETRVsgg(args, detr_kwargs, matcher) + elif args.hoi_detection: + model = DeformableDETRHoi(args, detr_kwargs, matcher) + else: + if args.masks: + model = DeformableDETRSegm(mask_kwargs, detr_kwargs) + else: + model = DeformableDETR(**detr_kwargs) + else: + transformer = build_transformer(args) + + detr_kwargs['transformer'] = transformer + + if args.tracking: + if args.masks: + model = DETRSegmTracking(mask_kwargs, tracking_kwargs, detr_kwargs) + else: + model = DETRTracking(tracking_kwargs, detr_kwargs) + else: + if args.masks: + model = DETRSegm(mask_kwargs, detr_kwargs) + else: + model = DETR(**detr_kwargs) + + weight_dict = {'loss_ce': args.cls_loss_coef, + 'loss_bbox': args.bbox_loss_coef, + 'loss_giou': args.giou_loss_coef,} + + if args.masks: + weight_dict["loss_mask"] = args.mask_loss_coef + weight_dict["loss_dice"] = args.dice_loss_coef + + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f'_enc': v for k, v 
in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + if args.masks: + losses.append('masks') + + if args.hoi_detection: + if args.freeze_detr: weight_dict = {} + weight_dict.update({'loss_relation_proposal': 1, 'loss_relation': 1}) + if args.hoi_aux_loss: + for i in range(args.hoi_dec_layers - 1): + weight_dict.update({f'loss_relation_{i}': weight_dict['loss_relation']}) + + criterion = SetCriterion( + num_classes, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=args.eos_coef, + losses=losses, + track_query_false_positive_eos_weight=args.track_query_false_positive_eos_weight, + focal_loss=args.focal_loss, + focal_alpha=args.focal_alpha, + aux_use_intermediate_match=(not args.tracking and not args.video_sgg_train) + ) + criterion.to(device) + + if args.focal_loss: + postprocessors = {'bbox': DeformablePostProcess()} + else: + postprocessors = {'bbox': PostProcess()} + if args.masks: + postprocessors['segm'] = PostProcessSegm() + if args.dataset == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors diff --git a/src/trackformer/models/backbone.py b/src/trackformer/models/backbone.py new file mode 100644 index 0000000..d616568 --- /dev/null +++ b/src/trackformer/models/backbone.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from typing import Dict, List + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.feature_pyramid_network import (FeaturePyramidNetwork, + LastLevelMaxPool) + +from ..util.misc import NestedTensor, is_main_process +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. 
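+    The statistics and the affine parameters are registered as buffers, so they are
+    restored from checkpoints but never updated by the optimizer.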
+ """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, + return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if (not train_backbone + or 'layer2' not in name + and 'layer3' not in name + and 'layer4' not in name): + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [4, 8, 16, 32] + self.num_channels = [256, 512, 1024, 2048] + else: + return_layers = {'layer4': "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=norm_layer) + super().__init__(backbone, train_backbone, + return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for x in xs.values(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks or (args.num_feature_levels > 1) + backbone = Backbone(args.backbone, + train_backbone, + return_interm_layers, + args.dilation) + model = Joiner(backbone, position_embedding) + return model diff --git a/src/trackformer/models/deformable_detr.py 
b/src/trackformer/models/deformable_detr.py new file mode 100644 index 0000000..4d55b61 --- /dev/null +++ b/src/trackformer/models/deformable_detr.py @@ -0,0 +1,270 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import copy +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from ..util import box_ops +from ..util.misc import (NestedTensor, accuracy, get_world_size, + inverse_sigmoid, is_dist_avail_and_initialized, + nested_tensor_from_tensor_list, sigmoid_focal_loss) +from .detr import DETR, PostProcess, SetCriterion + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(DETR): + """ This is the Deformable DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, + aux_loss=True, with_box_refine=False, two_stage=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal + number of objects DETR can detect in a single image. For COCO, + we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + """ + super().__init__(backbone, transformer, num_classes, num_queries, aux_loss) + + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, self.hidden_dim * 2) + num_channels = backbone.num_channels[-3:] + if num_feature_levels > 1: + # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + num_backbone_outs = len(backbone.strides) - 1 + + input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = num_channels[i] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1), + nn.GroupNorm(32, self.hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, self.hidden_dim), + )) + in_channels = self.hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(num_channels[0], self.hidden_dim, kernel_size=1), + nn.GroupNorm(32, self.hidden_dim), + )]) + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones_like(self.class_embed.bias) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for + # region proposal generation + num_pred = transformer.decoder.num_layers + if two_stage: + num_pred += 1 + + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + # @property + # def fpn_channels(self): + # """ Returns FPN channels. """ + # num_backbone_outs = len(self.backbone.strides) + # return [self.hidden_dim, ] * num_backbone_outs + + def forward(self, samples: NestedTensor, targets: list = None): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). 
These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + features_all = features + # pos_all = pos + # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + features = features[-3:] + pos = pos[-3:] + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage: + query_embeds = self.query_embed.weight + hs, memory, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = \ + self.transformer(srcs, masks, pos, query_embeds, targets) + + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + out = {'pred_logits': outputs_class[-1], + 'pred_boxes': outputs_coord[-1], + 'hs_embed': hs[-1]} + + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} + + offset = 0 + memory_slices = [] + batch_size, _, channels = memory.shape + for src in srcs: + _, _, height, width = src.shape + memory_slice = memory[:, offset:offset + height * width].permute(0, 2, 1).view( + batch_size, channels, height, width) + memory_slices.append(memory_slice) + offset += height * width + + memory = memory_slices + # memory = memory_slices[-1] + # features = [NestedTensor(memory_slide) for memory_slide in memory_slices] + + return out, targets, features_all, memory, hs + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
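+        # Each entry corresponds to one intermediate decoder layer; the final layer is
+        # excluded here because its prediction is already returned as the main output.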
+ return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class DeformablePostProcess(PostProcess): + """ This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + + ### + # topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) + # scores = topk_values + + # topk_boxes = topk_indexes // out_logits.shape[2] + # labels = topk_indexes % out_logits.shape[2] + + # boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + # boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + ### + + scores, labels = prob.max(-1) + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [ + {'scores': s, 'scores_no_object': 1 - s, 'labels': l, 'boxes': b} + for s, l, b in zip(scores, labels, boxes)] + + return results diff --git a/src/trackformer/models/deformable_transformer.py b/src/trackformer/models/deformable_transformer.py new file mode 100644 index 0000000..9ce9789 --- /dev/null +++ b/src/trackformer/models/deformable_transformer.py @@ -0,0 +1,434 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import math + +import torch +from torch import nn +from torch.nn.init import constant_, normal_, xavier_uniform_ + +from ..util.misc import inverse_sigmoid +from .ops.modules import MSDeformAttn +from .transformer import _get_clones, _get_activation_fn +import matplotlib.pyplot as plt + +class DeformableTransformer(nn.Module): + def __init__(self, args, d_model=256, nhead=8, + num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, + dropout=0.1, activation="relu", return_intermediate_dec=False, + num_feature_levels=4, dec_n_points=4, enc_n_points=4, + two_stage=False, two_stage_num_proposals=300): + super().__init__() + self.args = args + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, dec_n_points) + self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + self.reference_points = nn.Linear(d_model, 2) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.) 
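+        # level_embed is a learnable per-feature-level embedding that is added to the
+        # positional encodings in forward(), so the encoder can tell the levels apart.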
+ normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None, targets=None): + assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), 
spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) + + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + + reference_points = self.reference_points(query_embed).sigmoid() + + if targets is not None and 'track_query_hs_embeds' in targets[0]: + prev_hs_embed = torch.stack([t['track_query_hs_embeds'] for t in targets]) + + if self.args.track_query_propagation_strategy == 'consistent_pairing': + assert prev_hs_embed.shape == query_embed.shape + if self.args.tracking_token_propagation: + prop_mask = torch.stack([t['track_token_propagation_mask'] for t in targets]).unsqueeze(-1) + + prev_hs_embed_tgt_mapping = torch.stack([t['track_query_hs_embeds_tgt_mapping'] for t in targets]) + # tgt = (1-prop_mask) * tgt + prop_mask * prev_hs_embed_tgt_mapping # query tgt + tgt = (1-prop_mask/2) * tgt + prop_mask/2 * prev_hs_embed_tgt_mapping + + prev_hs_embed_pos_mapping = torch.stack([t['track_query_hs_embeds_pos_mapping'] for t in targets]) + # query_embed = query_embed + prop_mask * prev_hs_embed_pos_mapping + query_embed = (1-prop_mask/2) * query_embed + prop_mask/2 * prev_hs_embed_pos_mapping + + prev_boxes = torch.stack([t['track_query_boxes'] for t in targets]) + reference_points = (1-prop_mask) * reference_points + prop_mask * prev_boxes[..., :2] # ref points + + # # visualize reference points + # norm_ref_points = reference_points[0].cpu() + # plt.scatter(norm_ref_points[:, 0][prop_mask.cpu().squeeze()==0], (1-norm_ref_points[:, 1][prop_mask.cpu().squeeze()==0])) + # plt.scatter(norm_ref_points[:, 0][prop_mask.cpu().squeeze()!=0], (1-norm_ref_points[:, 1][prop_mask.cpu().squeeze()!=0]), marker='+', c='red') + # plt.show() + elif self.args.track_query_propagation_strategy == 'trackformer': + if self.args.tracking_token_propagation: + prev_boxes = torch.stack([t['track_query_boxes'] for t in targets]) + query_embed = torch.cat([torch.zeros_like(prev_hs_embed), query_embed], dim=1) + tgt = torch.cat([prev_hs_embed, tgt], dim=1) + # query_embed = torch.cat([prev_hs_embed, query_embed], dim=1) + # tgt = torch.cat([torch.zeros_like(prev_hs_embed), tgt], dim=1) + reference_points = torch.cat([prev_boxes[..., :2], reference_points], dim=1) + + # # visualize reference points + # norm_ref_points = reference_points[0].cpu() + # 
plt.scatter(norm_ref_points[prev_boxes.shape[1]:, 0], 1-norm_ref_points[prev_boxes.shape[1]:, 1]) + # plt.scatter(norm_ref_points[:prev_boxes.shape[1], 0], 1-norm_ref_points[:prev_boxes.shape[1], 1], marker='+', c='red') + # plt.show() + + init_reference_out = reference_points + + # decoder + hs, inter_references = self.decoder( + tgt, reference_points, memory, spatial_shapes, level_start_index, + valid_ratios, query_embed, mask_flatten) + + inter_references_out = inter_references + + # offset = 0 + # memory_slices = [] + # for src in srcs: + # _, _, height, width = src.shape + # memory_slice = memory[:, offset:offset +h * w].permute(0, 2, 1).view( + # bs, c, height, width) + # memory_slices.append(memory_slice) + # offset += h * w + + # # memory = memory_slices[-1] + # print([m.shape for m in memory_slices]) + + if self.two_stage: + return (hs, memory, init_reference_out, inter_references_out, + enc_outputs_class, enc_outputs_coord_unact) + return hs, memory, init_reference_out, inter_references_out, None, None + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + output = src + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + for _, layer in enumerate(self.layers): + output = layer(output, pos, reference_points, spatial_shapes, 
level_start_index, padding_mask) + + return output + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + reference_points, + src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios, + query_pos=None, src_padding_mask=None): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] \ + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) 
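+            # With return_intermediate=True the output and reference points of every decoder
+            # layer are collected, so DeformableDETR can build the per-layer predictions
+            # consumed by the auxiliary losses.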
+ + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def build_deforamble_transformer(args): + return DeformableTransformer( + args, + d_model=args.hidden_dim, + nhead=args.nheads, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=args.num_feature_levels, + dec_n_points=args.dec_n_points, + enc_n_points=args.enc_n_points, + two_stage=args.two_stage, + two_stage_num_proposals=args.num_queries) + + diff --git a/src/trackformer/models/detr.py b/src/trackformer/models/detr.py new file mode 100644 index 0000000..cdfd140 --- /dev/null +++ b/src/trackformer/models/detr.py @@ -0,0 +1,511 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn +import copy +from ..util import box_ops +from ..util.misc import (NestedTensor, accuracy, dice_loss, get_world_size, + interpolate, is_dist_avail_and_initialized, + nested_tensor_from_tensor_list, sigmoid_focal_loss) +import math + +class DETR(nn.Module): + """ This is the DETR module that performs object detection. """ + + def __init__(self, backbone, transformer, num_classes, num_queries, + aux_loss=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal + number of objects DETR can detect in a single image. For COCO, we + recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + + self.num_queries = num_queries + self.transformer = transformer + self.class_embed = nn.Linear(self.hidden_dim, num_classes + 1) + self.bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3) + self.query_embed = nn.Embedding(num_queries, self.hidden_dim) + + # match interface with deformable DETR + self.input_proj = nn.Conv2d(backbone.num_channels[-1], self.hidden_dim, kernel_size=1) + # self.input_proj = nn.ModuleList([ + # nn.Sequential( + # nn.Conv2d(backbone.num_channels[-1], self.hidden_dim, kernel_size=1) + # )]) + + self.backbone = backbone + self.aux_loss = aux_loss + + @property + def hidden_dim(self): + """ Returns the hidden feature dimension size. """ + return self.transformer.d_model + + @property + def fpn_channels(self): + """ Returns FPN channels. """ + return self.backbone.num_channels[:3][::-1] + # return [1024, 512, 256] + + def forward(self, samples: NestedTensor, targets: list = None): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], + containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). 
These values are normalized + in [0, 1], relative to the size of each individual image + (disregarding possible padding). See PostProcess for information + on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It + is a list of dictionnaries containing the two above keys for + each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + src, mask = features[-1].decompose() + # src = self.input_proj[-1](src) + src = self.input_proj(src) + pos = pos[-1] + + batch_size, _, _, _ = src.shape + + query_embed = self.query_embed.weight + query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1) + tgt = None + if targets is not None and 'track_query_hs_embeds' in targets[0]: + # [BATCH_SIZE, NUM_PROBS, 4] + track_query_hs_embeds = torch.stack([t['track_query_hs_embeds'] for t in targets]) + + num_track_queries = track_query_hs_embeds.shape[1] + + track_query_embed = torch.zeros( + num_track_queries, + batch_size, + self.hidden_dim).to(query_embed.device) + query_embed = torch.cat([ + track_query_embed, + query_embed], dim=0) + + tgt = torch.zeros_like(query_embed) + tgt[:num_track_queries] = track_query_hs_embeds.transpose(0, 1) + + for i, target in enumerate(targets): + target['track_query_hs_embeds'] = tgt[:, i] + + assert mask is not None + hs, hs_without_norm, memory = self.transformer( + src, mask, query_embed, pos, tgt) + + outputs_class = self.class_embed(hs) + outputs_coord = self.bbox_embed(hs).sigmoid() + out = {'pred_logits': outputs_class[-1], + 'pred_boxes': outputs_coord[-1], + 'hs_embed': hs_without_norm[-1]} + + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss( + outputs_class, outputs_coord) + + return out, targets, features, memory, hs + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, + track_query_false_positive_eos_weight, focal_loss, focal_alpha, aux_use_intermediate_match=True): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their + relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of + available losses. 
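+            track_query_false_positive_eos_weight: if True, the no-object weighting is
+                removed for false positive track queries in the classification loss
+                (see loss_labels).
+            focal_loss: use the sigmoid focal loss for classification instead of
+                cross entropy.
+            focal_alpha: alpha weighting factor of the focal loss.
+            aux_use_intermediate_match: re-run the matcher for each auxiliary decoder
+                layer output instead of reusing the final-layer matching.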
+ """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + self.track_query_false_positive_eos_weight = track_query_false_positive_eos_weight + self.focal_loss = focal_loss + self.focal_alpha = focal_alpha + + self.aux_use_intermediate_match = aux_use_intermediate_match + + def loss_labels(self, outputs, targets, indices, _, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + # target_classes[torch.stack([t["track_queries_match_mask"] == -1.0 for t in targets])] = 1 + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), + target_classes, + weight=self.empty_weight, + reduction='none') + + if self.track_query_false_positive_eos_weight: + for i, target in enumerate(targets): + if 'track_query_boxes' in target: + # remove no-object weighting for false track_queries + loss_ce[i, targets[i]['track_queries_match_mask'] == -1] *= 1 / self.eos_coef + # assign false track_queries to some object class for the final weighting + target_classes = target_classes.clone() + target_classes[i, targets[i]['track_queries_match_mask'] == -1] = 0 + + # ## hardcode: emphasize on detecting new objects + # new_obj_query_mask = (targets[i]['track_queries_match_mask'][indices[i][0]] == 0) + # new_obj_queries = indices[i][0][new_obj_query_mask] + # if len(new_obj_queries) > 1: + # loss_ce[i, new_obj_queries] *= 10 + + loss_ce = loss_ce.sum() / self.empty_weight[target_classes].sum() + + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:,:,:-1] + + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + 
@torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of + predicted non-empty boxes. This is not really a loss, it is intended + for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss + and the GIoU loss targets dicts must contain the key "boxes" containing + a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in + format (center_x, center_y, h, w), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of + dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, _ = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels_focal if self.focal_loss else self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, + see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor( + [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the + # output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + if self.aux_use_intermediate_match: ## else use match of final prediction + indices = self.matcher(aux_outputs, targets) + + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. 
+ continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'enc_outputs' in outputs: + enc_outputs = outputs['enc_outputs'] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt['labels'] = torch.zeros_like(bt['labels']) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_enc': v for k, v in l_dict.items()} + losses.update(l_dict) + + # for relation losses + if ('pred_rel_pairs' in outputs_without_aux) and ('pred_relations' in outputs_without_aux): + all_rel_pair_targets = [] + for imgid, (tgt, (det_idxs, gtbox_idxs)) in enumerate(zip(targets, indices)): + det2gt_map = {int(d): int(g) for d, g in zip(det_idxs, gtbox_idxs)} + gt_relation_map = tgt['relation_map'] + rel_pairs = outputs['pred_rel_pairs'][imgid] + rel_pair_targets = torch.zeros((len(rel_pairs), gt_relation_map.shape[-1])).to(gt_relation_map.device) + for idx, rel in enumerate(rel_pairs): + if (int(rel[0]) in det2gt_map) and (int(rel[1]) in det2gt_map): + rel_pair_targets[idx] = gt_relation_map[det2gt_map[int(rel[0])], det2gt_map[int(rel[1])]] + all_rel_pair_targets.append(rel_pair_targets) + all_rel_pair_targets = torch.cat(all_rel_pair_targets, dim=0) + + losses.update({ + 'loss_relation_proposal': self.relation_proposal_loss(torch.cat(outputs['pred_relation_exists'], dim=0), (all_rel_pair_targets.sum(-1) > 0).float()), + 'loss_relation': self.relation_loss(torch.cat(outputs['pred_relations'], dim=0), all_rel_pair_targets) + }) + if 'relation_aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['relation_aux_outputs']): + losses.update({ + f'loss_relation_{i}': self.relation_loss(torch.cat(aux_outputs['pred_relations'], dim=0), all_rel_pair_targets) + }) + + return losses + + def relation_proposal_loss(self, inputs, targets, gamma=2): + # focal loss to balance positive/negative + probs = inputs.sigmoid() * 0.9999 # for stability + pos_inds = targets.eq(1).float() + neg_inds = targets.lt(1).float() + pos_loss = torch.log(probs) * torch.pow(1 - probs, gamma) * pos_inds + neg_loss = torch.log(1 - probs) * torch.pow(probs, gamma) * neg_inds + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + # normalize + num_pos = pos_inds.float().sum() + if num_pos == 0: + loss = -neg_loss + else: + loss = -(pos_loss + neg_loss) / num_pos + return loss + + def relation_loss(self, inputs, targets, gamma=2): + probs = inputs.sigmoid() * 0.9999 + 1e-8 # for stability + + # focal loss to balance hard and easy + pos_inds = targets.eq(1).float() + neg_inds = targets.lt(1).float() + pos_loss = torch.log(probs) * torch.pow(1 - probs, gamma) * pos_inds + neg_loss = torch.log(1 - probs) * torch.pow(probs, gamma) * neg_inds + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + # normalize + num_pos = pos_inds.float().sum() + if num_pos == 0: + loss = -neg_loss + else: + loss = -(pos_loss + neg_loss) / num_pos + + return loss + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + 
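+    # Rough usage sketch (shapes taken from forward() below): given the raw model outputs
+    # and target_sizes of shape [batch_size, 2] holding the original (h, w) of each image,
+    # forward() returns one dict per image with 'scores', 'labels', 'scores_no_object' and
+    # 'boxes' in absolute [x0, y0, x1, y1] coordinates.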
+ def process_boxes(self, boxes, target_sizes): + # convert to [x0, y0, x1, y1] format + boxes = box_ops.box_cxcywh_to_xyxy(boxes) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + return boxes + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of + each images of the batch For evaluation, this must be the + original image size (before any data augmentation) For + visualization, this should be the image size after data + augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = F.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + boxes = self.process_boxes(out_bbox, target_sizes) + + results = [ + {'scores': s, 'labels': l, 'boxes': b, 'scores_no_object': s_n_o} + for s, l, b, s_n_o in zip(scores, labels, boxes, prob[..., -1])] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) + for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/src/trackformer/models/detr_hoi.py b/src/trackformer/models/detr_hoi.py new file mode 100644 index 0000000..3eb13e9 --- /dev/null +++ b/src/trackformer/models/detr_hoi.py @@ -0,0 +1,203 @@ +from .deformable_detr import DeformableDETR +from ..util.misc import NestedTensor +from ..util import box_ops +import torch +import torch.nn as nn +from .transformer import TransformerDecoderLayer, TransformerDecoder +from .position_encoding import build_position_encoding +from collections import OrderedDict +import torchvision + +class DeformableDETRHoi(DeformableDETR): + def __init__(self, args, detr_kwargs, matcher): + DeformableDETR.__init__(self, **detr_kwargs) + + self.args = args + self._matcher = matcher + + if args.freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + # instance representation + if self.args.hoi_instance_fuse_spatial_and_semantic_feat: + self.spatial_dim, self.semantic_dim = 128, 128 + self.spatial_embed = nn.Linear(4, self.spatial_dim) + self.semantic_embed = nn.Embedding(args.num_classes+1, self.semantic_dim) + self.instance_representation_fuse = nn.Sequential( + nn.Linear(self.args.hidden_dim+self.spatial_dim+self.semantic_dim, self.args.hidden_dim), nn.ReLU() + ) + + # for rel/interactions prediction + rel_rep_dim = self.args.hidden_dim * 2 + self.relation_proposal_mlp = nn.Sequential( + nn.Linear(rel_rep_dim, rel_rep_dim // 2), nn.ReLU(), + nn.Linear(rel_rep_dim // 2, 1) + ) + self.rel_query_pre_proj = nn.Linear(rel_rep_dim, self.args.hidden_dim) + + rel_dec_hidden_dim = self.args.hidden_dim + self.memory_input_proj = nn.Conv2d(2048, self.args.hidden_dim, kernel_size=1) + self.rel_memory_pos = build_position_encoding(args) + + decoder_layer = TransformerDecoderLayer(d_model=rel_dec_hidden_dim, nhead=8) + decoder_norm = 
nn.LayerNorm(rel_dec_hidden_dim) + if self.args.hoi_use_interaction_decoder: + self.interaction_decoder = TransformerDecoder(decoder_layer, None, self.args.hoi_dec_layers, decoder_norm, return_intermediate=True) + self.relation_embed = nn.Linear(rel_dec_hidden_dim, self.args.num_relations) + + if self.args.hoi_oracle_mode and self.args.hoi_oracle_mode_use_roialign_union_feat: + self.fpn = torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork(in_channels_list=[256, 512, 1024, 2048], out_channels=256) + self.box_pooler = torchvision.ops.MultiScaleRoIAlign(['0', '1', '2', '3'], 7, sampling_ratio=2) + self.box_pool_fc = nn.Sequential(nn.Linear(256*7*7, self.args.hidden_dim), nn.ReLU()) + self.union_pool_fc = nn.Linear(self.args.hidden_dim*3, self.args.hidden_dim*2) + + def forward(self, samples: NestedTensor, targets: list = None): + outs, _, features_all, _, _ = super().forward(samples, targets) + + # memory input for relation transformer decoder + memory_input_feature, memory_input_mask = features_all[-1].decompose() + memory_pos = self.rel_memory_pos(features_all[-1]) + memory_input = self.memory_input_proj(memory_input_feature) + + det2gt_indices = None + if self.training or self.args.hoi_oracle_mode: + det2gt_indices = self._matcher(outs, targets) + gt_rel_pairs = [] + for idx, ((ds, gs), t) in enumerate(zip(det2gt_indices, targets)): + gt2det_map = torch.zeros(len(gs)).to(device=ds.device, dtype=ds.dtype) + gt2det_map[gs] = ds + gt_rels = gt2det_map[t['relation_map'].sum(-1).nonzero(as_tuple=False)] + gt_rel_pairs.append(gt_rels) + + if self.args.hoi_oracle_mode: + outs['pred_logits'][idx, :, -1] = 1e3 # set default class as background + outs['pred_logits'][idx, ds, t['labels'][gs]] = 1e6 + outs['pred_boxes'][idx, ds] = t['boxes'][gs] + if 'aux_outputs' in outs: + for o in outs['aux_outputs']: + o['pred_logits'][idx, :, -1] = 1e3 + o['pred_logits'][idx, ds, t['labels'][gs]] = 1e6 + o['pred_boxes'][idx, ds] = t['boxes'][gs] + + pred_relation_exists, pred_rel_pairs, pred_relations = [], [], [] + bs, num_nodes = samples.tensors.shape[0], self.args.num_queries + for imgid in range(bs): + # >>>>>>>>>>>> relation proposal <<<<<<<<<<<<<<< + probs = outs['pred_logits'][imgid].softmax(-1) + inst_scores, inst_labels = probs[:, :-1].max(-1) + human_instance_ids = torch.logical_and(inst_scores>0.1, inst_labels==0).nonzero(as_tuple=False) # class0: person + + rel_mat = torch.zeros((num_nodes, num_nodes)) + rel_mat[human_instance_ids] = 1 + if self.args.hoi_oracle_mode: + gt_mask = torch.zeros_like(rel_mat) + gt_mask[det2gt_indices[imgid][0]] += 1; gt_mask[:, det2gt_indices[imgid][0]] += 1 + rel_mat[gt_mask!=2] = 0 + + if self.training: + if self.args.hoi_oracle_mode: + rel_mat[gt_rel_pairs[imgid][:, :1], det2gt_indices[imgid][0]] = 1 + else: + rel_mat[gt_rel_pairs[imgid][:, :1]] = 1 + rel_mat[gt_rel_pairs[imgid][:, 0], gt_rel_pairs[imgid][:, 1]] = 0 + rel_mat.fill_diagonal_(0) + rel_pairs = rel_mat.nonzero(as_tuple=False) # neg pairs + + if self.args.hoi_hard_mining: + all_pairs = torch.cat([gt_rel_pairs[imgid], rel_pairs], dim=0) + gt_pair_count = len(gt_rel_pairs[imgid]) + all_rel_reps = self._build_relation_representations(outs, all_pairs, imgid, features_all=features_all, image_size=samples.tensors.shape[-2:]) + p_relation_exist_logits = self.relation_proposal_mlp(all_rel_reps) + + gt_inds = torch.arange(gt_pair_count).to(p_relation_exist_logits.device) + # _, sort_rel_inds = p_relation_exist_logits[gt_pair_count:].squeeze(1).sort(descending=True) + _, sort_rel_inds = 
torch.cat([inst_scores[all_pairs], p_relation_exist_logits.sigmoid()], dim=-1).prod(-1)[gt_pair_count:].sort(descending=True) + sampled_rel_inds = torch.cat([gt_inds, sort_rel_inds+gt_pair_count])[:self.args.num_hoi_queries] + + sampled_rel_pairs = all_pairs[sampled_rel_inds] + sampled_rel_reps = all_rel_reps[sampled_rel_inds] + sampled_rel_pred_exists = p_relation_exist_logits.squeeze(1)[sampled_rel_inds] + else: + sampled_neg_inds = torch.randperm(len(rel_pairs)) # random sampling + sampled_rel_pairs = torch.cat([gt_rel_pairs[imgid], rel_pairs[sampled_neg_inds]], dim=0)[:self.args.num_hoi_queries] + sampled_rel_reps = self._build_relation_representations(outs, sampled_rel_pairs, imgid, features_all=features_all, image_size=samples.tensors.shape[-2:]) + sampled_rel_pred_exists = self.relation_proposal_mlp(sampled_rel_reps).squeeze(1) + else: + rel_mat.fill_diagonal_(0) + rel_pairs = rel_mat.nonzero(as_tuple=False) + rel_reps = self._build_relation_representations(outs, rel_pairs, imgid, features_all=features_all, image_size=samples.tensors.shape[-2:]) + p_relation_exist_logits = self.relation_proposal_mlp(rel_reps) + + # _, sort_rel_inds = p_relation_exist_logits.squeeze(1).sort(descending=True) + _, sort_rel_inds = torch.cat([inst_scores[rel_pairs], p_relation_exist_logits.sigmoid()], dim=-1).prod(-1).sort(descending=True) + sampled_rel_inds = sort_rel_inds[:self.args.num_hoi_queries] + + sampled_rel_pairs = rel_pairs[sampled_rel_inds] + sampled_rel_reps = rel_reps[sampled_rel_inds] + sampled_rel_pred_exists = p_relation_exist_logits.squeeze(1)[sampled_rel_inds] + + # >>>>>>>>>>>> relation classification <<<<<<<<<<<<<<< + query_reps = self.rel_query_pre_proj(sampled_rel_reps).unsqueeze(1) + if self.args.hoi_use_interaction_decoder: + relation_outs, _ = self.interaction_decoder(tgt=query_reps, + memory=memory_input[imgid:imgid+1].flatten(2).permute(2,0,1), + memory_key_padding_mask=memory_input_mask[imgid:imgid+1].flatten(1), + pos=memory_pos[imgid:imgid+1].flatten(2).permute(2, 0, 1)) + else: + relation_outs = query_reps.unsqueeze(0) + relation_logits = self.relation_embed(relation_outs) + + pred_rel_pairs.append(sampled_rel_pairs) + pred_relations.append(relation_logits) + pred_relation_exists.append(sampled_rel_pred_exists) + + outs.update({ + "pred_rel_pairs": pred_rel_pairs, + "pred_relations": [p[-1].squeeze(1) for p in pred_relations], + "pred_relation_exists": pred_relation_exists, + "det2gt_indices": det2gt_indices, + }) + + if self.args.hoi_aux_loss: + outs['relation_aux_outputs'] = self._set_hoi_aux_loss(pred_relations) + + return outs, targets, None, None, None + + @torch.jit.unused + def _set_hoi_aux_loss(self, pred_relations): + return [{'pred_relations': [p[l].squeeze(1) for p in pred_relations]} for l in range(self.args.hoi_dec_layers - 1)] + + def _build_relation_representations(self, outs, rel_pairs, imgid, features_all=None, image_size=None): + inst_reps = outs['hs_embed'][imgid] + + if self.args.hoi_instance_fuse_spatial_and_semantic_feat: + inst_spatial_reps = self.spatial_embed(outs['pred_boxes'][imgid]) + inst_semantic_reps = outs['pred_logits'][imgid].softmax(-1) @ self.semantic_embed.weight + inst_reps = self.instance_representation_fuse(torch.cat([inst_reps, inst_spatial_reps, inst_semantic_reps], dim=-1)) + + rel_reps = torch.cat([inst_reps[rel_pairs[:, 0]], inst_reps[rel_pairs[:, 1]]], dim=1) + + # fuse roi_align union feature + if self.args.hoi_oracle_mode and self.args.hoi_oracle_mode_use_roialign_union_feat: + feat_order_dict = OrderedDict() + for lvl, feat 
in enumerate(features_all): + feat_order_dict[str(lvl)] = feat.tensors + fpn_feats = self.fpn(feat_order_dict) + + xyxy_boxes = box_ops.box_cxcywh_to_xyxy(outs['pred_boxes'][imgid]) + img_h, img_w = image_size + scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(xyxy_boxes.device) + xyxy_boxes = xyxy_boxes * scale_fct[None, :] + subj_boxes, obj_boxes = xyxy_boxes[rel_pairs[:, 0]], xyxy_boxes[rel_pairs[:, 1]] + union_boxes = torch.cat([torch.min(subj_boxes[:, :2], obj_boxes[:, :2]), torch.max(subj_boxes[:, 2:], obj_boxes[:, 2:])], dim=-1) + union_pool_feats = self.box_pool_fc( + self.box_pooler(fpn_feats, [union_boxes], [image_size]).view(-1, 256*7*7) + ) + rel_reps = self.union_pool_fc(torch.cat([rel_reps, union_pool_feats], dim=-1)) + + return rel_reps + + def tracking(self): + """Compatible with vsgg eval""" + self.eval() diff --git a/src/trackformer/models/detr_segmentation.py b/src/trackformer/models/detr_segmentation.py new file mode 100644 index 0000000..130cb8e --- /dev/null +++ b/src/trackformer/models/detr_segmentation.py @@ -0,0 +1,385 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +This file provides the definition of the convolutional heads used +to predict masks, as well as the losses. +""" +import io +from collections import defaultdict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch import Tensor + +from ..util import box_ops +from ..util.misc import NestedTensor, interpolate + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + +from .deformable_detr import DeformableDETR +from .detr import DETR +from .detr_tracking import DETRTrackingBase + + +class DETRSegmBase(nn.Module): + def __init__(self, freeze_detr=False): + if freeze_detr: + for param in self.parameters(): + param.requires_grad_(False) + + nheads = self.transformer.nhead + self.bbox_attention = MHAttentionMap(self.hidden_dim, self.hidden_dim, nheads, dropout=0.0) + + self.mask_head = MaskHeadSmallConv( + self.hidden_dim + nheads, self.fpn_channels, self.hidden_dim) + + def forward(self, samples: NestedTensor, targets: list = None): + out, targets, features, memory, hs = super().forward(samples, targets) + + if isinstance(memory, list): + src, mask = features[-2].decompose() + batch_size = src.shape[0] + + src = self.input_proj[-3](src) + mask = F.interpolate(mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + + # fpns = [memory[2], memory[1], memory[0]] + fpns = [features[-2].tensors, features[-3].tensors, features[-4].tensors] + memory = memory[-3] + else: + src, mask = features[-1].decompose() + batch_size = src.shape[0] + + src = self.input_proj(src) + + fpns = [features[2].tensors, features[1].tensors, features[0].tensors] + + # FIXME h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head(src, bbox_mask, fpns) + outputs_seg_masks = seg_masks.view( + batch_size, hs.shape[2], seg_masks.shape[-2], seg_masks.shape[-1]) + + out["pred_masks"] = outputs_seg_masks + + return out, targets, features, memory, hs + + +# TODO: with meta classes +class DETRSegm(DETRSegmBase, DETR): + def __init__(self, mask_kwargs, detr_kwargs): + DETR.__init__(self, **detr_kwargs) + DETRSegmBase.__init__(self, **mask_kwargs) + + +class DeformableDETRSegm(DETRSegmBase, DeformableDETR): + def __init__(self, mask_kwargs, detr_kwargs): + DeformableDETR.__init__(self, **detr_kwargs) 
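+ # NOTE: the detector __init__ must run before DETRSegmBase.__init__, which reads
+ # self.transformer.nhead, self.hidden_dim and self.fpn_channels to build the mask head
+ # and, when freeze_detr is set, freezes only the detector parameters that exist so far.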
+ DETRSegmBase.__init__(self, **mask_kwargs) + + +class DETRSegmTracking(DETRSegmBase, DETRTrackingBase, DETR): + def __init__(self, mask_kwargs, tracking_kwargs, detr_kwargs): + DETR.__init__(self, **detr_kwargs) + DETRTrackingBase.__init__(self, **tracking_kwargs) + DETRSegmBase.__init__(self, **mask_kwargs) + + +class DeformableDETRSegmTracking(DETRSegmBase, DETRTrackingBase, DeformableDETR): + def __init__(self, mask_kwargs, tracking_kwargs, detr_kwargs): + DeformableDETR.__init__(self, **detr_kwargs) + DETRTrackingBase.__init__(self, **tracking_kwargs) + DETRSegmBase.__init__(self, **mask_kwargs) + + +def _expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + inter_dims = [ + dim, + context_dim // 2, + context_dim // 4, + context_dim // 8, + context_dim // 16, + context_dim // 64] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): + x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns + the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear 
= nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask: Optional[Tensor] = None): + q = self.q_linear(q) + k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + kh = k.view( + k.shape[0], + self.num_heads, + self.hidden_dim // self.num_heads, + k.shape[-2], + k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes, max_target_sizes, return_probs=False): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate( + outputs_masks, + size=(max_h, max_w), + mode="bilinear", + align_corners=False) + + outputs_masks = outputs_masks.sigmoid().cpu() + if not return_probs: + outputs_masks = outputs_masks > self.threshold + + zip_iter = zip(outputs_masks, max_target_sizes, orig_target_sizes) + for i, (cur_mask, t, tt) in enumerate(zip_iter): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ) + + if not return_probs: + results[i]["masks"] = results[i]["masks"].byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, + in the format expected by the coco panoptic API """ + + def __init__(self, is_thing_map, threshold=0.85): + """ + Parameters: + is_thing_map: This is a whose keys are the class ids, and the values + a boolean indicating whether the class is a thing (True) + or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than + this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes, target_sizes=None): + """ This function computes the panoptic prediction from the model's predictions. + Parameters: + outputs: This is a dict coming directly from the model. See the model + doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes + of the images that were passed to the model, ie the + size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding + to the requested final size of each prediction. 
If left to + None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = \ + outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class + # (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) + + np_seg_img = (torch.ByteTensor( + torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor([ + area[i] <= 4 + for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append({ + "id": i, + "isthing": self.is_thing_map[cat], + "category_id": cat, + "area": a}) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = 
{"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds diff --git a/src/trackformer/models/detr_tracking.py b/src/trackformer/models/detr_tracking.py new file mode 100644 index 0000000..73c3690 --- /dev/null +++ b/src/trackformer/models/detr_tracking.py @@ -0,0 +1,326 @@ +import random + +import torch +import torch.nn as nn + +from ..util import box_ops +from ..util.misc import NestedTensor +from .deformable_detr import DeformableDETR +from .detr import DETR +import numpy as np + +class DETRTrackingBase(nn.Module): + + def __init__(self, + track_query_false_positive_prob=0.0, + track_query_false_negative_prob=0.0, + track_query_noise=0.0, + matcher=None, + track_query_propagation_strategy='trackformer', + tracking_token_propagation=True, + clip_mode=False, + detection_obj_score_thresh=0.9, + token_propagation_sample_rate=0.1, + tracking_match_propagation_skip_frame=False): + self._matcher = matcher + self._track_query_false_positive_prob = track_query_false_positive_prob + self._track_query_false_negative_prob = track_query_false_negative_prob + self._track_query_noise = track_query_noise + + self._tracking = False + + self.track_query_propagation_strategy = track_query_propagation_strategy + self.tracking_token_propagation = tracking_token_propagation + + if self.track_query_propagation_strategy == 'consistent_pairing' and self.tracking_token_propagation: + self.propagation_mlp = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256) + ) + self.clip_mode = clip_mode + self.detection_obj_score_thresh = detection_obj_score_thresh + self.token_propagation_sample_rate = token_propagation_sample_rate + self.tracking_match_propagation_skip_frame = tracking_match_propagation_skip_frame + + def train(self, mode: bool = True): + """Sets the module in train mode.""" + self._tracking = False + return super().train(mode) + + def tracking(self): + """Sets the module in tracking mode.""" + self.eval() + self._tracking = True + + def foward_tracking_inference(self, samples: NestedTensor, targets: list = None): + if self.track_query_propagation_strategy == 'consistent_pairing' and self.tracking_token_propagation and targets is not None: # from the 2nd frame under consistent_pairing + for t in targets: + t['track_query_hs_embeds_mapping'] = self.propagation_mlp(t['track_query_hs_embeds']) # originally prev_outs + + out, targets, features, memory, hs = super().forward(samples, targets) + return out, targets, features, memory, hs + + def foward_train_val_2frame_mode(self, samples: NestedTensor, targets: list = None): + assert self.track_query_propagation_strategy == 'trackformer' + if targets is not None: + prev_out, *_ = super().forward([targets[0]['prev_image']]) ## detr detection forward + + prev_outputs_without_aux = { + k: v for k, v in prev_out.items() if 'aux_outputs' not in k} + prev_targets = [ + {k.replace('prev_', ''): v for k, v in target.items() if "prev" in k} + for target in targets] + prev_indices = self._matcher(prev_outputs_without_aux, prev_targets) + + for i, (target, prev_ind) in enumerate(zip(targets, prev_indices)): + prev_out_ind, prev_target_ind = prev_ind + + # random subset ## 随机擦除上一帧检测到的部分目标 FN Augmentation + if self._track_query_false_negative_prob and self.track_query_propagation_strategy == 'trackformer': + random_subset_mask = torch.empty(len(prev_target_ind)).uniform_() + random_subset_mask = random_subset_mask.ge( + self._track_query_false_negative_prob) + + prev_out_ind = prev_out_ind[random_subset_mask] 
+ prev_target_ind = prev_target_ind[random_subset_mask] + + ## transfer matching from prev to current + # detected prev frame tracks + prev_track_ids = target['prev_track_ids'][prev_target_ind] + + # match track ids between frames + target_ind_match_matrix = prev_track_ids.unsqueeze(dim=1).eq(target['track_ids']) + target_ind_matching = target_ind_match_matrix.any(dim=1) + target_ind_matched_idx = target_ind_match_matrix.nonzero()[:, 1] + + # current frame track ids detected in the prev frame + # track_ids = target['track_ids'][target_ind_matched_idx] + + # index of prev frame detection in current frame box list + target['track_query_match_ids'] = target_ind_matched_idx + + not_prev_out_ind = torch.arange(prev_out['pred_boxes'].shape[1]) + not_prev_out_ind = [ + ind.item() + for ind in not_prev_out_ind + if ind not in prev_out_ind] + random_false_out_ind = [] + + # random false positives + prev_boxes_matched = prev_out['pred_boxes'][i, prev_out_ind[target_ind_matching]] + for prev_box_matched in prev_boxes_matched: + ## 从背景集合,随机挑选噪声 FP Augmentation + if random.uniform(0, 1) < self._track_query_false_positive_prob: + prev_boxes_unmatched = prev_out['pred_boxes'][i, not_prev_out_ind] + + # only cxcy + # box_dists = prev_box_matched[:2].sub(prev_boxes_unmatched[:, :2]).abs() + # box_dists = box_dists.pow(2).sum(dim=-1).sqrt() + # box_weights = 1.0 / box_dists.add(1e-8) + + prev_box_ious, _ = box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(prev_box_matched.unsqueeze(dim=0)), + box_ops.box_cxcywh_to_xyxy(prev_boxes_unmatched)) + box_weights = prev_box_ious[0] + + if box_weights.gt(0.0).any(): + random_false_out_idx = not_prev_out_ind.pop( + torch.multinomial(box_weights.cpu(), 1).item()) + random_false_out_ind.append(random_false_out_idx) + + prev_out_ind = torch.tensor(prev_out_ind.tolist() + random_false_out_ind).long() + target_ind_matching = torch.tensor( + target_ind_matching.tolist() + [False, ] * len(random_false_out_ind)).bool() + + # matches indices with 1.0 and not matched -1.0 + track_queries_match_mask = torch.ones_like(target_ind_matching).float() + track_queries_match_mask[~target_ind_matching] = -1.0 + + # set prev frame info + hs_embeds = prev_out['hs_embed'][i, prev_out_ind] + if self._track_query_noise and not torch.isnan(hs_embeds.std()).any(): + track_query_noise = torch.randn_like(hs_embeds) \ + * hs_embeds.std(dim=1, keepdim=True) + hs_embeds = hs_embeds + track_query_noise * self._track_query_noise + # hs_embeds = track_query_noise * self._track_query_noise \ + # + hs_embeds * (1 - self._track_query_noise) + target['track_query_hs_embeds'] = hs_embeds + target['track_query_boxes'] = prev_out['pred_boxes'][i, prev_out_ind].detach() + + # add zeros for detection object queries + device = track_queries_match_mask.device + track_queries_match_mask = torch.tensor( + track_queries_match_mask.tolist() + [0, ] * self.num_queries) + + target['track_queries_match_mask'] = track_queries_match_mask.to(device) + + out, targets, features, memory, hs = super().forward(samples, targets) + return out, targets, features, memory, hs + + def foward_train_val_clip_mode(self, samples: NestedTensor, targets: list = None): + outs = [] + for frame_id, frame_target in enumerate(targets): + frame_image = samples.tensors[frame_id] + out, *_ = super().forward([frame_image], [frame_target]) + outs.append(out) + + # propagate to the next frame + if frame_id < len(targets)-1: + prev_out = out + target = targets[frame_id+1] # target for the next frame + prev_indices = self._matcher(prev_out, 
[frame_target]) + prev_out_ind, prev_target_ind = prev_indices[0] + prev_tid2qid = {int(tid): int(qid) for tid, qid in zip(frame_target['track_ids'][prev_target_ind], prev_out_ind)} + + # random subset ## 随机擦除上一帧检测到的部分目标 FN Augmentation + if self._track_query_false_negative_prob: + random_subset_mask = torch.empty(len(prev_target_ind)).uniform_() + random_subset_mask = random_subset_mask.ge( + self._track_query_false_negative_prob) + + prev_out_ind = prev_out_ind[random_subset_mask] + prev_target_ind = prev_target_ind[random_subset_mask] + + ## transfer matching from prev to current + # detected prev frame tracks + prev_track_ids = frame_target['track_ids'][prev_target_ind] + + # match track ids between frames + target_ind_match_matrix = prev_track_ids.unsqueeze(dim=1).eq(target['track_ids']) + target_ind_matching = target_ind_match_matrix.any(dim=1) + target_ind_matched_idx = target_ind_match_matrix.nonzero()[:, 1] + + # current frame track ids detected in the prev frame + # track_ids = target['track_ids'][target_ind_matched_idx] + + # index of prev frame detection in current frame box list + target['track_query_match_ids'] = target_ind_matched_idx + + if self.track_query_propagation_strategy == 'trackformer': + if not self.tracking_token_propagation: continue + not_prev_out_ind = torch.arange(prev_out['pred_boxes'].shape[1]) + not_prev_out_ind = [ + ind.item() + for ind in not_prev_out_ind + if ind not in prev_out_ind] + random_false_out_ind = [] + + # random false positives + prev_boxes_matched = prev_out['pred_boxes'][0, prev_out_ind[target_ind_matching]] + for prev_box_matched in prev_boxes_matched: + ## 从背景集合,随机挑选噪声 FP Augmentation + if random.uniform(0, 1) < self._track_query_false_positive_prob: + prev_boxes_unmatched = prev_out['pred_boxes'][0, not_prev_out_ind] + + # only cxcy + # box_dists = prev_box_matched[:2].sub(prev_boxes_unmatched[:, :2]).abs() + # box_dists = box_dists.pow(2).sum(dim=-1).sqrt() + # box_weights = 1.0 / box_dists.add(1e-8) + + prev_box_ious, _ = box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(prev_box_matched.unsqueeze(dim=0)), + box_ops.box_cxcywh_to_xyxy(prev_boxes_unmatched)) + box_weights = prev_box_ious[0] + + if box_weights.gt(0.0).any(): + random_false_out_idx = not_prev_out_ind.pop( + torch.multinomial(box_weights.cpu(), 1).item()) + random_false_out_ind.append(random_false_out_idx) + + prev_out_ind = torch.tensor(prev_out_ind.tolist() + random_false_out_ind).long() + target_ind_matching = torch.tensor( + target_ind_matching.tolist() + [False, ] * len(random_false_out_ind)).bool() + + # matches indices with 1.0 and not matched -1.0 + track_queries_match_mask = torch.ones_like(target_ind_matching).float() + track_queries_match_mask[~target_ind_matching] = -1.0 + + # set prev frame info + hs_embeds = prev_out['hs_embed'][0, prev_out_ind] + if self._track_query_noise and not torch.isnan(hs_embeds.std()).any(): + track_query_noise = torch.randn_like(hs_embeds) \ + * hs_embeds.std(dim=1, keepdim=True) + hs_embeds = hs_embeds + track_query_noise * self._track_query_noise + # hs_embeds = track_query_noise * self._track_query_noise \ + # + hs_embeds * (1 - self._track_query_noise) + target['track_query_hs_embeds'] = hs_embeds + target['track_query_boxes'] = prev_out['pred_boxes'][0, prev_out_ind].detach() + + # add zeros for detection object queries + device = track_queries_match_mask.device + track_queries_match_mask = torch.tensor( + track_queries_match_mask.tolist() + [0, ] * self.num_queries) + + target['track_queries_match_mask'] = 
track_queries_match_mask.to(device) + elif self.track_query_propagation_strategy == 'consistent_pairing': + device = target['track_query_match_ids'].device + tracked_qids = prev_out_ind[target_ind_matching] + + if self.tracking_match_propagation_skip_frame: # skip frame match propagation + target['prev_tid2qid'] = prev_tid2qid + if 'prev_tid2qid' in frame_target: + prev_prev_tids = set(frame_target['prev_tid2qid'].keys()) + prev_tids = set(frame_target['track_ids'].tolist()) + cur_tids = set(target['track_ids'].tolist()) + reappear_tids = list((cur_tids & prev_prev_tids) - prev_tids) + if len(reappear_tids) > 0: + # print(reappear_tids) + reappear_qids = [frame_target['prev_tid2qid'][tid] for tid in reappear_tids] + tracked_qids_append_reappear = tracked_qids.tolist() + reappear_qids + tracked_ids_append_reappear = target['track_query_match_ids'].tolist() + [target['track_ids'].tolist().index(tid) for tid in reappear_tids] + + new_order = np.argsort(tracked_qids_append_reappear) + tracked_qids = torch.tensor(np.array(tracked_qids_append_reappear)[new_order]) + target['track_query_match_ids'] = torch.tensor(np.array(tracked_ids_append_reappear)[new_order]).to(device) + + # match mask to next frame + track_queries_match_mask = torch.zeros(prev_out['hs_embed'][0].shape[0]).float() + track_queries_match_mask[tracked_qids] = 1 # tracked in current frame + track_queries_match_mask[prev_out_ind[~target_ind_matching]] = -1 # disappeared in current frame + target['track_queries_match_mask'] = track_queries_match_mask.to(device) + + if self.tracking_token_propagation: + target['track_query_hs_embeds'] = prev_out['hs_embed'][0] + target['track_query_hs_embeds_mapping'] = self.propagation_mlp(prev_out['hs_embed'][0]) + target['track_query_boxes'] = prev_out['pred_boxes'][0].detach() + + if random.random() < self.token_propagation_sample_rate: + target['track_token_propagation_mask'] = (prev_out['pred_logits'][0].softmax(-1)[:,:-1].max(-1)[0].detach() > self.detection_obj_score_thresh).float() + else: + target['track_token_propagation_mask'] = (target['track_queries_match_mask'] != 0).float() + + if self.track_query_propagation_strategy == 'consistent_pairing': + ## compose outputs in batched format + outputs = {key: torch.cat([o[key] for o in outs], dim=0) for key in ['pred_logits', 'pred_boxes', 'hs_embed']} + if 'aux_outputs' in outs[0]: + outputs['aux_outputs'] = [] + for l in range(len(outs[0]['aux_outputs'])): + outputs['aux_outputs'].append( + {key: torch.cat([o['aux_outputs'][l][key] for o in outs], dim=0) for key in ['pred_logits', 'pred_boxes']} + ) + return outputs, targets, None, None, None + else: + return outs, targets, None, None, None + + + def forward(self, samples: NestedTensor, targets: list = None): + if self._tracking: # tracking on inference + return self.foward_tracking_inference(samples, targets) + else: + if self.clip_mode: + return self.foward_train_val_clip_mode(samples, targets) + else: + return self.foward_train_val_2frame_mode(samples, targets) + +# TODO: with meta classes +class DETRTracking(DETRTrackingBase, DETR): + def __init__(self, tracking_kwargs, detr_kwargs): + DETR.__init__(self, **detr_kwargs) + DETRTrackingBase.__init__(self, **tracking_kwargs) + + +class DeformableDETRTracking(DETRTrackingBase, DeformableDETR): + def __init__(self, tracking_kwargs, detr_kwargs): + DeformableDETR.__init__(self, **detr_kwargs) + DETRTrackingBase.__init__(self, **tracking_kwargs) diff --git a/src/trackformer/models/detr_vsgg.py b/src/trackformer/models/detr_vsgg.py new file mode 
100644 index 0000000..51fa17a --- /dev/null +++ b/src/trackformer/models/detr_vsgg.py @@ -0,0 +1,401 @@ +import copy +import random +import torch +import torch.nn as nn +import torchvision +from collections import OrderedDict + +from ..util.misc import NestedTensor +from ..util.box_ops import box_cxcywh_to_xyxy +from .deformable_detr import DeformableDETR +from .position_encoding import build_position_encoding, TemporalEmbeddingLearned +from .transformer import TransformerDecoderLayer, TransformerDecoder, TransformerEncoderLayer, TransformerEncoder + +class DeformableDETRVsgg(DeformableDETR): + def __init__(self, args, detr_kwargs, matcher): + DeformableDETR.__init__(self, **detr_kwargs) + assert args.track_query_propagation_strategy == 'consistent_pairing' + + self.args = args + self._matcher = matcher + self._tracking = False + + if args.tracking_token_propagation: + self.pos_propagation_mlp = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256) + ) + self.tgt_propagation_mlp = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256) + ) + + if args.hoi_detection: + if args.freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + else: + for n, p in self.named_parameters(): + if 'backbone.0.body' in n: + p.requires_grad_(False) + + # instance representation + if self.args.hoi_instance_fuse_spatial_and_semantic_feat: + self.spatial_dim, self.semantic_dim = 128, 128 + self.spatial_embed = nn.Linear(4, self.spatial_dim) + self.semantic_embed = nn.Embedding(args.num_classes+1, self.semantic_dim) + self.instance_representation_fuse = nn.Sequential( + nn.Linear(self.args.hidden_dim+self.spatial_dim+self.semantic_dim, self.args.hidden_dim), nn.ReLU() + ) + + # for rel/interactions prediction + rel_rep_dim = self.args.hidden_dim * 2 + self.relation_proposal_mlp = nn.Sequential( + nn.Linear(rel_rep_dim, rel_rep_dim // 2), nn.ReLU(), + nn.Linear(rel_rep_dim // 2, 1) + ) + self.rel_query_pre_proj = nn.Linear(rel_rep_dim, self.args.hidden_dim) + + rel_dec_hidden_dim = self.args.hidden_dim + self.memory_input_proj = nn.Conv2d(2048, self.args.hidden_dim, kernel_size=1) + self.rel_memory_pos = build_position_encoding(args) + + decoder_layer = TransformerDecoderLayer(d_model=rel_dec_hidden_dim, nhead=8) + decoder_norm = nn.LayerNorm(rel_dec_hidden_dim) + self.interaction_decoder = TransformerDecoder(decoder_layer, None, self.args.hoi_dec_layers, decoder_norm, return_intermediate=True) + self.relation_embed = nn.Linear(rel_dec_hidden_dim, self.args.num_relations) + + if self.args.hoi_use_temporal_dynamics: + encoder_layer = TransformerEncoderLayer(d_model=self.args.hidden_dim, nhead=4) + self.temporal_dynamic_encoder = TransformerEncoder(encoder_layer, num_layers=2) + self.temporal_position_encoding = TemporalEmbeddingLearned(self.args.hoi_use_temporal_dynamics_prev_length+1) + + if self.args.hoi_oracle_mode: + if self.args.hoi_oracle_mode_use_instant_trajectory: + self.traj_feat_dim = 4 * 24 * 2 # 4* frame_num * (subj+obj) + self.trajectory_feature_fc = nn.Linear(self.traj_feat_dim, self.traj_feat_dim) + self.fuse_trajectory_feature_mlp = nn.Sequential( + nn.Linear(self.args.hidden_dim + self.traj_feat_dim, self.args.hidden_dim), nn.ReLU(), + nn.Linear(self.args.hidden_dim, self.args.hidden_dim) + ) + if self.args.hoi_oracle_mode_use_roialign_union_feat: + self.fpn = torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork(in_channels_list=[256, 512, 1024, 2048], out_channels=256) + self.box_pooler = torchvision.ops.MultiScaleRoIAlign(['0', '1', 
'2', '3'], 7, sampling_ratio=2) + self.box_pool_fc = nn.Sequential(nn.Linear(256*7*7, self.args.hidden_dim), nn.ReLU()) + self.union_pool_fc = nn.Linear(self.args.hidden_dim*3, self.args.hidden_dim*2) + + def train(self, mode: bool = True): + """Sets the module in train mode.""" + self._tracking = False + return super().train(mode) + + def tracking(self): + """Sets the module in tracking mode.""" + self.eval() + self._tracking = True + + # online inference + def online_foward(self, samples: NestedTensor, targets: list = None): + assert len(targets) == 1 + if self.args.tracking_token_propagation and 'track_query_hs_embeds' in targets[0]: + targets[0]['track_query_hs_embeds_pos_mapping'] = self.pos_propagation_mlp(targets[0]['track_query_hs_embeds']) + targets[0]['track_query_hs_embeds_tgt_mapping'] = self.tgt_propagation_mlp(targets[0]['track_query_hs_embeds']) + + out, out_targets, features, memory, hs = super().forward(samples, targets) + if self.args.hoi_detection: + if self.args.hoi_oracle_mode: # set Oracle mode (given GT objects) + match_res = self._matcher(out, targets, only_match_by_bbox=self.args.hoi_oracle_mode_only_given_bbox)[0] + # match_res = self._matcher(out, targets)[0] + if not self.args.hoi_oracle_mode_only_given_bbox: # PredCls mode, otherwise SGCls + out['pred_logits'][0, :, -1] = 1e3 # BG class + out['pred_logits'][0, match_res[0], targets[0]['labels'][match_res[1]]] = 1e6 + out['pred_boxes'][0, match_res[0]] = targets[0]['boxes'][match_res[1]] + out['match_res'] = match_res + out = self.hoi_forward(out, features, targets[0], image_size=samples.tensors.shape[-2:]) + return out, out_targets, features, memory, hs + + def forward(self, samples: NestedTensor, targets: list = None): + if self._tracking: + return self.online_foward(samples, targets) + + outs = [] + for frame_id, frame_target in enumerate(targets): + frame_image = samples.tensors[frame_id] + out, _, features_all, _, _ = super().forward([frame_image], [frame_target]) + + if self.args.hoi_detection: + if self.args.hoi_oracle_mode: # set Oracle mode (given GT objects) + match_res = self._matcher(out, [frame_target], only_match_by_bbox=self.args.hoi_oracle_mode_only_given_bbox)[0] + # match_res = self._matcher(out, [frame_target])[0] + if not self.args.hoi_oracle_mode_only_given_bbox: # PredCls mode, otherwise SGCls + out['pred_logits'][0, :, -1] = 1e3 + out['pred_logits'][0, match_res[0], frame_target['labels'][match_res[1]]] = 1e6 + out['pred_boxes'][0, match_res[0]] = frame_target['boxes'][match_res[1]] + out['match_res'] = match_res + + if 'aux_outputs' in outs: # aux outputs + for o in outs['aux_outputs']: + if not self.args.hoi_oracle_mode_only_given_bbox: + o['pred_logits'][0, :, -1] = 1e3 + o['pred_logits'][0, match_res[0], frame_target['labels'][match_res[1]]] = 1e6 + o['pred_boxes'][0, match_res[0]] = frame_target['boxes'][match_res[1]] + out = self.hoi_forward(out, features_all, frame_target, image_size=frame_image.shape[-2:]) + outs.append(out) + + # propagate to the next frame + if frame_id < len(targets)-1: + prev_out = out + target = targets[frame_id+1] # target for the next frame + if 'match_res' in prev_out: + prev_out_ind, prev_target_ind = prev_out['match_res'] + else: + prev_indices = self._matcher(prev_out, [frame_target]) + prev_out_ind, prev_target_ind = prev_indices[0] + + ## transfer matching from prev to current + # detected prev frame tracks + prev_track_ids = frame_target['track_ids'][prev_target_ind] + + # match track ids between frames + target_ind_match_matrix = 
prev_track_ids.unsqueeze(dim=1).eq(target['track_ids']) + target_ind_matching = target_ind_match_matrix.any(dim=1) + target_ind_matched_idx = target_ind_match_matrix.nonzero()[:, 1] + + # index of prev frame detection in current frame box list + target['track_query_match_ids'] = target_ind_matched_idx + device = target['track_query_match_ids'].device + tracked_qids = prev_out_ind[target_ind_matching] + + # match mask to next frame + track_queries_match_mask = torch.zeros(prev_out['hs_embed'][0].shape[0]).float() + track_queries_match_mask[tracked_qids] = 1 # tracked in current frame + track_queries_match_mask[prev_out_ind[~target_ind_matching]] = -1 # disappeared in current frame + target['track_queries_match_mask'] = track_queries_match_mask.to(device) + + if self.args.tracking_token_propagation: + target['track_query_hs_embeds'] = prev_out['hs_embed'][0] + target['track_query_hs_embeds_pos_mapping'] = self.pos_propagation_mlp(prev_out['hs_embed'][0]) + target['track_query_hs_embeds_tgt_mapping'] = self.tgt_propagation_mlp(prev_out['hs_embed'][0]) + target['track_query_boxes'] = prev_out['pred_boxes'][0].detach() + if random.random() < self.args.token_propagation_sample_rate: + target['track_token_propagation_mask'] = (prev_out['pred_logits'][0].softmax(-1)[:,:-1].max(-1)[0].detach() > 0.7).float() + else: + target['track_token_propagation_mask'] = (target['track_queries_match_mask'] != 0).float() + + if self.args.hoi_detection and self.args.hoi_use_temporal_dynamics: + if 'temporal_dynamics_feature_bank' in frame_target: + target['temporal_dynamics_feature_bank'] = torch.cat((prev_out['hs_embed'], frame_target['temporal_dynamics_feature_bank']), dim=0)[:self.args.hoi_use_temporal_dynamics_prev_length] + target['temporal_dynamics_feature_mask'] = torch.cat(((target['track_token_propagation_mask']==0).unsqueeze(0), frame_target['temporal_dynamics_feature_mask']), dim=0)[:self.args.hoi_use_temporal_dynamics_prev_length] + else: + target['temporal_dynamics_feature_bank'] = prev_out['hs_embed'] + target['temporal_dynamics_feature_mask'] = (target['track_token_propagation_mask']==0).unsqueeze(0) + + ## compose outputs in batched format + outputs = {key: torch.cat([o[key] for o in outs], dim=0) for key in ['pred_logits', 'pred_boxes', 'hs_embed']} + if 'aux_outputs' in outs[0]: + outputs['aux_outputs'] = [] + for l in range(len(outs[0]['aux_outputs'])): + outputs['aux_outputs'].append( + {key: torch.cat([o['aux_outputs'][l][key] for o in outs], dim=0) for key in ['pred_logits', 'pred_boxes']} + ) + + if self.args.hoi_detection: + outputs.update({key: [o[key][0] for o in outs] for key in ['pred_rel_pairs', 'pred_relations', 'pred_relation_exists']}) + if 'relation_aux_outputs' in outs[0]: + outputs['relation_aux_outputs'] = [] + for l in range(len(outs[0]['relation_aux_outputs'])): + outputs['relation_aux_outputs'].append( + {key: [o['relation_aux_outputs'][l][key][0] for o in outs] for key in ['pred_relations']} + ) + + return outputs, targets, None, None, None + + def hoi_forward(self, out, features_all, frame_target=None, image_size=None): + assert len(features_all[-1].tensors) == 1 # frame-wise forward + + # memory input for relation transformer decoder + memory_input_feature, memory_input_mask = features_all[-1].decompose() + memory_pos = self.rel_memory_pos(features_all[-1]) + memory_input = self.memory_input_proj(memory_input_feature) + + # instance representations + if self.args.hoi_use_temporal_dynamics: + if (frame_target is not None) and 'temporal_dynamics_feature_bank' in 
frame_target: + src = torch.cat((out['hs_embed'], frame_target['temporal_dynamics_feature_bank']), dim=0) + att_mask = torch.cat((torch.zeros_like(frame_target['temporal_dynamics_feature_mask'])[:1], frame_target['temporal_dynamics_feature_mask']), dim=0).permute(1,0) + pos_idx = torch.arange(len(src)).to(src.device).unsqueeze(-1) + else: + src = out['hs_embed'] + att_mask = torch.zeros((src.shape[1], 1), dtype=torch.bool, device=src.device) + pos_idx = torch.arange(1).to(src.device).unsqueeze(-1) + + instance_representations = self.temporal_dynamic_encoder( + src=src, + src_key_padding_mask=att_mask, + pos=self.temporal_position_encoding(pos_idx), + )[0] + else: + instance_representations = out['hs_embed'][0] + + if self.training: + if 'match_res' in out: + ds, gs = out['match_res'] + else: + ds, gs = self._matcher(out, [frame_target])[0] + gt2det_map = torch.zeros(len(gs)).to(device=ds.device, dtype=ds.dtype) + gt2det_map[gs] = ds + gt_rel_pairs = gt2det_map[frame_target['relation_map'].sum(-1).nonzero(as_tuple=False)] + + imgid, num_nodes= 0, self.args.num_queries + # >>>>>>>>>>>> relation proposal <<<<<<<<<<<<<<< + probs = out['pred_logits'][imgid].softmax(-1) + inst_scores, inst_labels = probs[:, :-1].max(-1) + human_instance_ids = torch.logical_and(inst_scores>0.1, inst_labels==0).nonzero(as_tuple=False) # class0: person + + rel_mat = torch.zeros((num_nodes, num_nodes)) + rel_mat[human_instance_ids] = 1 + if self.args.hoi_oracle_mode: + gt_mask = torch.zeros_like(rel_mat) + gt_mask[out['match_res'][0]] += 1; gt_mask[:, out['match_res'][0]] += 1 + rel_mat[gt_mask!=2] = 0 + + if self.training: + # sampling + if self.args.hoi_oracle_mode: + rel_mat[gt_rel_pairs[:, :1], ds] = 1 + else: + rel_mat[gt_rel_pairs[:, :1]] = 1 + rel_mat[gt_rel_pairs[:, 0], gt_rel_pairs[:, 1]] = 0 + rel_mat.fill_diagonal_(0) + rel_pairs = rel_mat.nonzero(as_tuple=False) # neg pairs + + if self.args.hoi_hard_mining: + all_pairs = torch.cat([gt_rel_pairs, rel_pairs], dim=0) + gt_pair_count = len(gt_rel_pairs) + all_rel_reps = self._build_relation_representations(instance_representations, out, all_pairs, imgid, features_all=features_all, image_size=image_size) + p_relation_exist_logits = self.relation_proposal_mlp(all_rel_reps) + + gt_inds = torch.arange(gt_pair_count).to(p_relation_exist_logits.device) + # _, sort_rel_inds = p_relation_exist_logits.sigmoid()[gt_pair_count:].squeeze(1).sort(descending=True) + _, sort_rel_inds = torch.cat([inst_scores[all_pairs], p_relation_exist_logits.sigmoid()], dim=-1).prod(-1)[gt_pair_count:].sort(descending=True) + sampled_rel_inds = torch.cat([gt_inds, sort_rel_inds+gt_pair_count])[:self.args.num_hoi_queries] + + sampled_rel_pairs = all_pairs[sampled_rel_inds] + sampled_rel_reps = all_rel_reps[sampled_rel_inds] + sampled_rel_pred_exists = p_relation_exist_logits.squeeze(1)[sampled_rel_inds] + else: + sampled_neg_inds = torch.randperm(len(rel_pairs)) # random sampling + sampled_rel_pairs = torch.cat([gt_rel_pairs, rel_pairs[sampled_neg_inds]], dim=0)[:self.args.num_hoi_queries] + sampled_rel_reps = self._build_relation_representations(instance_representations, out, sampled_rel_pairs, imgid, features_all=features_all, image_size=image_size) + sampled_rel_pred_exists = self.relation_proposal_mlp(sampled_rel_reps).squeeze(1) + else: + if self.args.hoi_relation_propagation_on_inference and 'prev_top_rel_pairs' in frame_target: + prev_rel_pairs = frame_target['prev_top_rel_pairs'].cpu() + else: + prev_rel_pairs = torch.zeros((0, 2)).long() + prev_pair_count = len(prev_rel_pairs) + 
prev_pair_inds = torch.arange(prev_pair_count) + + if not self.args.hoi_oracle_mode and self.args.hoi_inference_apply_nms: + bg_inds = self.apply_nms(inst_scores, inst_labels, out['pred_boxes'][imgid]) + rel_mat[:, bg_inds] = 0 + rel_mat[prev_rel_pairs[:, 0], prev_rel_pairs[:, 1]] = 0 + rel_mat.fill_diagonal_(0) + rel_pairs = rel_mat.nonzero(as_tuple=False) + + # predict interactiveness and sorting + rel_pairs = torch.cat([prev_rel_pairs, rel_pairs], dim=0) + rel_reps = self._build_relation_representations(instance_representations, out, rel_pairs, imgid, features_all=features_all, image_size=image_size) + p_relation_exist_logits = self.relation_proposal_mlp(rel_reps) + + # _, sort_rel_inds = p_relation_exist_logits.sigmoid()[prev_pair_count:].squeeze(1).sort(descending=True) + _, sort_rel_inds = torch.cat([inst_scores[rel_pairs], p_relation_exist_logits.sigmoid()], dim=-1).prod(-1)[prev_pair_count:].sort(descending=True) + sampled_rel_inds = torch.cat([prev_pair_inds.to(sort_rel_inds.device), + sort_rel_inds[:self.args.num_hoi_queries] + prev_pair_count]) + + sampled_rel_pairs = rel_pairs[sampled_rel_inds] + sampled_rel_reps = rel_reps[sampled_rel_inds] + sampled_rel_pred_exists = p_relation_exist_logits.squeeze(1)[sampled_rel_inds] + + # >>>>>>>>>>>> relation classification <<<<<<<<<<<<<<< + query_reps = self.rel_query_pre_proj(sampled_rel_reps).unsqueeze(1) + relation_outs, _ = self.interaction_decoder(tgt=query_reps, + memory=memory_input[imgid:imgid+1].flatten(2).permute(2,0,1), + memory_key_padding_mask=memory_input_mask[imgid:imgid+1].flatten(1), + pos=memory_pos[imgid:imgid+1].flatten(2).permute(2, 0, 1)) + if self.args.hoi_oracle_mode and self.args.hoi_oracle_mode_use_instant_trajectory: # for fair comparison with ST-HOI of fusing GT trajectory feature + ds2gs = torch.zeros(self.args.num_queries).long() - 1 + ds2gs[out['match_res'][0]] = out['match_res'][1] + traj_feats = frame_target['box_instant_trajectories'][ds2gs[sampled_rel_pairs]].view(-1, self.traj_feat_dim) + relation_outs = self.fuse_trajectory_feature_mlp( + torch.cat([relation_outs, self.trajectory_feature_fc(traj_feats.unsqueeze(0).unsqueeze(-2)).expand(len(relation_outs),-1,-1,-1)], dim=-1) + ) + + relation_logits = self.relation_embed(relation_outs) + out.update({ + "pred_rel_pairs": [sampled_rel_pairs], + "pred_relations": [relation_logits[-1].squeeze(1)], + "pred_relation_exists": [sampled_rel_pred_exists], + }) + + if self.args.hoi_aux_loss: + out['relation_aux_outputs'] = self._set_hoi_aux_loss([relation_logits]) + return out + + @torch.jit.unused + def _set_hoi_aux_loss(self, pred_relations): + return [{'pred_relations': [p[l].squeeze(1) for p in pred_relations]} for l in range(self.args.hoi_dec_layers - 1)] + + # merge boxes (NMS) + def apply_nms(self, inst_scores, inst_labels, cxcywh_boxes, threshold=0.7): + xyxy_boxes = box_cxcywh_to_xyxy(cxcywh_boxes) + box_areas = (xyxy_boxes[:, 2:] - xyxy_boxes[:, :2]).prod(-1) + box_area_sum = box_areas.unsqueeze(1) + box_areas.unsqueeze(0) + + union_boxes = torch.cat([torch.min(xyxy_boxes.unsqueeze(1)[:, :, :2], xyxy_boxes.unsqueeze(0)[:, :, :2]), + torch.max(xyxy_boxes.unsqueeze(1)[:, :, 2:], xyxy_boxes.unsqueeze(0)[:, :, 2:])], dim=-1) + union_area = (union_boxes[:,:,2:] - union_boxes[:,:,:2]).prod(-1) + iou = torch.clamp(box_area_sum - union_area, min=0) / union_area + box_match_mat = torch.logical_and(iou > threshold, inst_labels.unsqueeze(1) == inst_labels.unsqueeze(0)) + + suppress_ids = [] + for box_match in box_match_mat: + group_ids = 
box_match.nonzero(as_tuple=False).squeeze(1) + if len(group_ids) > 1: + max_score_inst_id = group_ids[inst_scores[group_ids].argmax()] + bg_ids = group_ids[group_ids!=max_score_inst_id] + suppress_ids.append(bg_ids) + box_match_mat[:, bg_ids] = False + if len(suppress_ids) > 0: + suppress_ids = torch.cat(suppress_ids, dim=0) + return suppress_ids + + def _build_relation_representations(self, inst_reps, outs, rel_pairs, imgid, features_all=None, image_size=None): + if self.args.hoi_instance_fuse_spatial_and_semantic_feat: + inst_spatial_reps = self.spatial_embed(outs['pred_boxes'][imgid]) + inst_semantic_reps = outs['pred_logits'][imgid].softmax(-1) @ self.semantic_embed.weight + inst_reps = self.instance_representation_fuse(torch.cat([inst_reps, inst_spatial_reps, inst_semantic_reps], dim=-1)) + + rel_reps = torch.cat([inst_reps[rel_pairs[:,0]], inst_reps[rel_pairs[:,1]]], dim=1) + + # fuse roi_align union feature + if self.args.hoi_oracle_mode and self.args.hoi_oracle_mode_use_roialign_union_feat: + feat_order_dict = OrderedDict() + for lvl, feat in enumerate(features_all): + feat_order_dict[str(lvl)] = feat.tensors + fpn_feats = self.fpn(feat_order_dict) + + xyxy_boxes = box_cxcywh_to_xyxy(outs['pred_boxes'][imgid]) + img_h, img_w = image_size + scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(xyxy_boxes.device) + xyxy_boxes = xyxy_boxes * scale_fct[None, :] + subj_boxes, obj_boxes = xyxy_boxes[rel_pairs[:, 0]], xyxy_boxes[rel_pairs[:, 1]] + union_boxes = torch.cat([torch.min(subj_boxes[:, :2], obj_boxes[:, :2]), torch.max(subj_boxes[:, 2:], obj_boxes[:, 2:])], dim=-1) + union_pool_feats = self.box_pool_fc( + self.box_pooler(fpn_feats, [union_boxes], [image_size]).view(-1, 256*7*7) + ) + rel_reps = self.union_pool_fc(torch.cat([rel_reps, union_pool_feats], dim=-1)) + + return rel_reps diff --git a/src/trackformer/models/matcher.py b/src/trackformer/models/matcher.py new file mode 100644 index 0000000..3e80903 --- /dev/null +++ b/src/trackformer/models/matcher.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from ..util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best + predictions, while the others are un-matched (and thus treated as non-objects). 
+ """ + + def __init__(self, args, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, + focal_loss: bool = False, focal_alpha: float = 0.25): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates + in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the + matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.focal_loss = focal_loss + self.focal_alpha = focal_alpha + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + self.args = args + + @torch.no_grad() + def forward(self, outputs, targets, only_match_by_bbox=False): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the + classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted + box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target + is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number + of ground-truth objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + batch_size, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + # + # [batch_size * num_queries, num_classes] + if self.focal_loss: + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() + else: + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) + + # [batch_size * num_queries, 4] + out_bbox = outputs["pred_boxes"].flatten(0, 1) + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. + if self.focal_loss: + gamma = 2.0 + neg_cost_class = (1 - self.focal_alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = self.focal_alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + if only_match_by_bbox: # Modification for SGCls mode in SGG task (no given tgt_ids) + cost_class *= 0 + else: + # Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be omitted. 
+ cost_class = -out_prob[:, tgt_ids] + if only_match_by_bbox: # Modification for SGCls mode in SGG task (no given tgt_ids) + cost_class = out_prob[:, -1:] # prob[BG] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), + box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + cost_matrix = self.cost_bbox * cost_bbox \ + + self.cost_class * cost_class \ + + self.cost_giou * cost_giou + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + + ## propagate prev match to current frame (for training tracking) + if self.args.tracking_match_propagation: + cost_min = min(cost_matrix.min(), -1) if cost_matrix.numel() > 0 else -1 + for i, target in enumerate(targets): + if 'track_query_match_ids' not in target: + continue + + prop_i = 0 + for j, mask_value in enumerate(target['track_queries_match_mask']): + if mask_value.item() == 1: + track_query_id = target['track_query_match_ids'][prop_i] # matched box_id + prop_i += 1 + + cost_matrix[i, j] = np.inf # img X queries X boxes + cost_matrix[i, :, track_query_id + sum(sizes[:i])] = np.inf + cost_matrix[i, j, track_query_id + sum(sizes[:i])] = cost_min + elif mask_value.item() == -1: + cost_matrix[i, j] = np.inf + + indices = [linear_sum_assignment(c[i]) + for i, c in enumerate(cost_matrix.split(sizes, -1))] + + # ## check match transfer for consistent_pairing + # for fid, (inds, t) in enumerate(zip(indices, targets)): + # qids, obj_ids = inds + # if fid > 0: + # print('\n') + # lost_qids = sorted(list(set(match_dict.keys()) - set(qids))) + # for q in lost_qids: + # print(f"Lost: qid={q} -> {match_dict[q]}") + # for q, o in zip(qids, obj_ids): + # if q in match_dict: # tracked obj + # if match_dict[q] == t['track_ids'][o]: + # print(f"Tracked: qid={q} ({match_dict[q]} -> {t['track_ids'][o]})") + # else: + # print(f"!!Match shift: qid={q} ({match_dict[q]} -> {t['track_ids'][o]})") + # else: # queries for new objs + # print(f"New: qid={q} -> {t['track_ids'][o]}") + # + # match_dict = {q: t['track_ids'][o] for q, o in zip(qids, obj_ids)} + + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher( + args, + cost_class=args.set_cost_class, + cost_bbox=args.set_cost_bbox, + cost_giou=args.set_cost_giou, + focal_loss=args.focal_loss, + focal_alpha=args.focal_alpha, + ) diff --git a/src/trackformer/models/ops/.gitignore b/src/trackformer/models/ops/.gitignore new file mode 100644 index 0000000..1fe80a2 --- /dev/null +++ b/src/trackformer/models/ops/.gitignore @@ -0,0 +1,5 @@ +build +dist +*egg-info +*.linux* +*.win* diff --git a/src/trackformer/models/ops/functions/__init__.py b/src/trackformer/models/ops/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/src/trackformer/models/ops/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/src/trackformer/models/ops/functions/ms_deform_attn_func.py b/src/trackformer/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..8c5df8c --- /dev/null +++ b/src/trackformer/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, 
M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/src/trackformer/models/ops/make.sh b/src/trackformer/models/ops/make.sh new file mode 100644 index 0000000..106b685 --- /dev/null +++ b/src/trackformer/models/ops/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +python setup.py build install diff --git a/src/trackformer/models/ops/modules/__init__.py b/src/trackformer/models/ops/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/src/trackformer/models/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/src/trackformer/models/ops/modules/ms_deform_attn.py b/src/trackformer/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/src/trackformer/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
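`ms_deform_attn_core_pytorch` above is the pure-PyTorch reference for the CUDA kernel ("for debug and test only"): it rescales sampling locations from [0, 1] to the [-1, 1] range expected by `F.grid_sample`, samples every feature level bilinearly, and reduces the samples with the attention weights. The toy snippet below (made-up shapes, not repository code) isolates the `grid_sample` step so the coordinate convention and output layout are easy to verify.

```python
# Toy illustration of the grid_sample convention used by ms_deform_attn_core_pytorch:
# sampling locations in [0, 1] are rescaled to [-1, 1] before bilinear sampling.
import torch
import torch.nn.functional as F

NM, D, H, W = 2, 4, 6, 4               # per-level value map, analogous to (N_*M_, D_, H_, W_)
Lq, P = 3, 4                           # queries and sampling points per query
value_l = torch.randn(NM, D, H, W)

loc = torch.rand(NM, Lq, P, 2)         # (x, y) in [0, 1], like sampling_locations
grid = 2 * loc - 1                     # grid_sample expects coordinates in [-1, 1]

sampled = F.grid_sample(value_l, grid, mode='bilinear',
                        padding_mode='zeros', align_corners=False)
print(sampled.shape)                   # torch.Size([2, 4, 3, 4]) == (N_*M_, D_, Lq_, P_)
```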
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
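`_reset_parameters` above zeroes the offset projection weights and seeds its bias so that, at initialization, each attention head looks in its own direction and the k-th sampling point steps k units further along that direction. The standalone reconstruction below repeats the same arithmetic with example sizes (the real values come from the constructor arguments), which makes the geometry easy to inspect.

```python
# Standalone reconstruction of the sampling-offset bias initialization above
# (same arithmetic; the n_heads / n_levels / n_points values are just examples).
import math
import torch

n_heads, n_levels, n_points = 8, 4, 4
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)          # one unit direction per head
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]   # scale so max |coord| == 1
grid_init = grid_init.view(n_heads, 1, 1, 2).repeat(1, n_levels, n_points, 1)
for i in range(n_points):
    grid_init[:, :, i, :] *= i + 1                                 # point i sits i+1 steps out

print(grid_init[0, 0])   # head 0: offsets (1, 0), (2, 0), (3, 0), (4, 0)
print(grid_init[2, 0])   # head 2: approximately (0, 1), (0, 2), (0, 3), (0, 4)
```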
+ + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/src/trackformer/models/ops/setup.py b/src/trackformer/models/ops/setup.py new file mode 100644 index 0000000..ac583dc --- /dev/null +++ b/src/trackformer/models/ops/setup.py @@ -0,0 +1,71 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
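`forward` above needs the flattened multi-scale features plus the bookkeeping tensors described in its docstring (`input_spatial_shapes`, `input_level_start_index`, and normalized `reference_points`). The shape-level smoke test below is a hypothetical sketch: it assumes the module is importable as `trackformer.models.ops.modules`, that the extension has been built via `make.sh`, and that a CUDA device is available.

```python
# Hypothetical shape-level smoke test for MSDeformAttn (assumes the compiled
# MultiScaleDeformableAttention extension and a CUDA device are available;
# the import path is an assumption).
import torch
from trackformer.models.ops.modules import MSDeformAttn

d_model, n_levels, n_heads, n_points = 256, 4, 8, 4
spatial_shapes = torch.as_tensor([[32, 32], [16, 16], [8, 8], [4, 4]],
                                 dtype=torch.long, device='cuda')
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
len_in = int(spatial_shapes.prod(1).sum())          # sum_l H_l * W_l

N, len_q = 2, 100
query = torch.randn(N, len_q, d_model, device='cuda')
input_flatten = torch.randn(N, len_in, d_model, device='cuda')
reference_points = torch.rand(N, len_q, n_levels, 2, device='cuda')   # normalized (x, y)

attn = MSDeformAttn(d_model, n_levels, n_heads, n_points).cuda()
out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)   # torch.Size([2, 100, 256])
```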
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not available') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000..e1bf854 --- /dev/null +++ b/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
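`setup.py` above compiles everything under `src/` into a single extension module named `MultiScaleDeformableAttention` and refuses to build without CUDA; the CPU sources that follow are stubs that only raise. After running `make.sh`, a quick check along these lines (a suggestion, not part of the patch) confirms the two bindings that `ms_deform_attn_func.py` calls are present:

```python
# Post-build sanity check (a suggestion, not part of the patch): the autograd
# Function above calls these two bindings from the compiled extension.
import MultiScaleDeformableAttention as MSDA

for name in ("ms_deform_attn_forward", "ms_deform_attn_backward"):
    assert hasattr(MSDA, name), f"extension is missing {name}; rebuild with ops/make.sh"
print("MultiScaleDeformableAttention extension loaded OK")
```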
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.h b/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000..81b7b58 --- /dev/null +++ b/src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.cu b/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000..d6d5836 --- /dev/null +++ b/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), 
"spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.h b/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000..c7ae53f --- /dev/null +++ b/src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/src/trackformer/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/src/trackformer/models/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..6bc2acb --- /dev/null +++ b/src/trackformer/models/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, 
w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + 
c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 
2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template 
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + 
*grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int 
batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + 
num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + 
num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/src/trackformer/models/ops/src/ms_deform_attn.h b/src/trackformer/models/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..ac0ef2e --- /dev/null +++ b/src/trackformer/models/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/src/trackformer/models/ops/src/vision.cpp b/src/trackformer/models/ops/src/vision.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/src/trackformer/models/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/src/trackformer/models/ops/test.py b/src/trackformer/models/ops/test.py new file mode 100644 index 0000000..8dbf6d5 --- /dev/null +++ b/src/trackformer/models/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = 
torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/src/trackformer/models/position_encoding.py b/src/trackformer/models/position_encoding.py new file mode 100644 index 0000000..fc8edb0 --- /dev/null +++ b/src/trackformer/models/position_encoding.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from ..util.misc import NestedTensor +import numpy as np + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
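+ Each spatial axis is encoded independently with interleaved sine and cosine features over a geometric range of frequencies, and the two encodings are concatenated along the channel dimension.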
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + # dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + dim_t = torch.tensor(np.arange(self.num_pos_feats) // 2, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * dim_t / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack(( + pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack(( + pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +class TemporalEmbeddingLearned(nn.Module): + """ + Absolute temporal embedding, learned. + """ + def __init__(self, num_embeds, num_pos_feats=256): + super().__init__() + self.temporal_embed = nn.Embedding(num_embeds, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.temporal_embed.weight) + + def forward(self, temporal_differences): + temporal_embed = self.temporal_embed(temporal_differences) + return temporal_embed + + +def build_position_encoding(args): + n_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(n_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(n_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/src/trackformer/models/tracker.py b/src/trackformer/models/tracker.py new file mode 100644 index 0000000..7fd44bf --- /dev/null +++ b/src/trackformer/models/tracker.py @@ -0,0 +1,656 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +""" +Tracker which achieves MOT with the provided object detector. +""" +from collections import deque + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torchvision.ops.boxes import clip_boxes_to_image, nms, box_iou + +from ..util.box_ops import box_xyxy_to_cxcywh +from ..util.plot_utils import check_prediction + + +class Tracker: + """The main tracking file, here is where magic happens.""" + + def __init__(self, obj_detector, obj_detector_post, tracker_cfg, + generate_attention_maps, logger=None): + self.obj_detector = obj_detector + self.obj_detector_post = obj_detector_post + self.detection_obj_score_thresh = tracker_cfg['detection_obj_score_thresh'] + self.track_obj_score_thresh = tracker_cfg['track_obj_score_thresh'] + self.detection_nms_thresh = tracker_cfg['detection_nms_thresh'] + self.track_nms_thresh = tracker_cfg['track_nms_thresh'] + self.public_detections = tracker_cfg['public_detections'] + self.inactive_patience = float(tracker_cfg['inactive_patience']) + self.reid_sim_threshold = tracker_cfg['reid_sim_threshold'] + self.reid_sim_only = tracker_cfg['reid_sim_only'] + self.generate_attention_maps = generate_attention_maps + self.reid_score_thresh = tracker_cfg['reid_score_thresh'] + self.reid_greedy_matching = tracker_cfg['reid_greedy_matching'] + + self.consistent_pairing_detection_thresh = tracker_cfg['consistent_pairing_detection_thresh'] + self.prev_outs = None + + if self.generate_attention_maps: + assert hasattr(self.obj_detector.transformer.decoder.layers[-1], 'multihead_attn'), 'Generation of attention maps not possible for deformable DETR.' + + attention_data = { + 'maps': None, + 'conv_features': {}, + 'hooks': []} + + hook = self.obj_detector.backbone[-2].register_forward_hook( + lambda self, input, output: attention_data.update({'conv_features': output})) + attention_data['hooks'].append(hook) + + def add_attention_map_to_data(self, input, output): + height, width = attention_data['conv_features']['3'].tensors.shape[-2:] + attention_maps = output[1].view(-1, height, width) + + attention_data.update({'maps': attention_maps}) + + multihead_attn = self.obj_detector.transformer.decoder.layers[-1].multihead_attn + hook = multihead_attn.register_forward_hook( + add_attention_map_to_data) + attention_data['hooks'].append(hook) + + self.attention_data = attention_data + + self._logger = logger + if self._logger is None: + self._logger = lambda *log_strs: None + + @property + def num_object_queries(self): + return self.obj_detector.num_queries + + def reset(self, hard=True): + self.tracks = [] + self.inactive_tracks = [] + self._prev_blob = None + self._prev_frame = None + + if hard: + self.track_num = 0 + self.results = {} + self.frame_index = 0 + self.num_reids = 0 + + @property + def device(self): + return next(self.obj_detector.parameters()).device + + def tracks_to_inactive(self, tracks): + self.tracks = [t for t in self.tracks if t not in tracks] + + for track in tracks: + track.pos = track.last_pos[-1] + self.inactive_tracks += tracks + + def add_tracks(self, pos, scores, hs_embeds, masks=None, attention_maps=None, query_ids=None): + """Initializes new Track objects and saves them.""" + new_track_ids = [] + for i in range(len(pos)): + self.tracks.append(Track( + pos[i], + scores[i], + self.track_num + i, + hs_embeds[i], + None if masks is None else masks[i], + None if attention_maps is None else attention_maps[i], + )) + new_track_ids.append(self.track_num + i) + + if 
query_ids is not None: + self.tracks[-1].query_id = query_ids[i] + self.track_num += len(new_track_ids) + + if new_track_ids: + self._logger( + f'INIT TRACK IDS (detection_obj_score_thresh={self.detection_obj_score_thresh}): ' + f'{new_track_ids}') + + return new_track_ids + + def public_detections_mask(self, new_det_boxes, public_det_boxes): + """Returns mask to filter current frame detections with provided set of + public detections.""" + + if not self.public_detections: + return torch.ones(new_det_boxes.size(0)).bool().to(self.device) + + if not len(public_det_boxes) or not len(new_det_boxes): + return torch.zeros(new_det_boxes.size(0)).bool().to(self.device) + + public_detections_mask = torch.zeros(new_det_boxes.size(0)).bool().to(self.device) + + if self.public_detections == 'center_distance': + item_size = [((box[2] - box[0]) * (box[3] - box[1])) + for box in new_det_boxes] + item_size = np.array(item_size, np.float32) + + new_det_boxes_cxcy = box_xyxy_to_cxcywh(new_det_boxes).cpu().numpy()[:,:2] + public_det_boxes_cxcy = box_xyxy_to_cxcywh(public_det_boxes).cpu().numpy()[:,:2] + + dist3 = new_det_boxes_cxcy.reshape(-1, 1, 2) - public_det_boxes_cxcy.reshape(1, -1, 2) + dist3 = (dist3 ** 2).sum(axis=2) + + for j in range(len(public_det_boxes)): + i = dist3[:, j].argmin() + + if dist3[i, j] < item_size[i]: + dist3[i, :] = 1e18 + public_detections_mask[i] = True + elif self.public_detections == 'min_iou_0_5': + iou_matrix = box_iou(new_det_boxes, public_det_boxes.to(self.device)) + + for j in range(len(public_det_boxes)): + i = iou_matrix[:, j].argmax() + + if iou_matrix[i, j] >= 0.5: + iou_matrix[i, :] = 0 + public_detections_mask[i] = True + else: + raise NotImplementedError + + return public_detections_mask + + def reid(self, new_det_boxes, new_det_scores, new_det_hs_embeds, + new_det_masks=None, new_det_attention_maps=None): + """Tries to ReID inactive tracks with provided detections.""" + + self.inactive_tracks = [ + t for t in self.inactive_tracks + if t.has_positive_area() and t.count_inactive <= self.inactive_patience + ] + + if not self.inactive_tracks or not len(new_det_boxes): + return torch.ones(new_det_boxes.size(0)).bool().to(self.device) + + # calculate distances + dist_mat = [] + if self.reid_greedy_matching: + new_det_boxes_cxcyhw = box_xyxy_to_cxcywh(new_det_boxes).cpu().numpy() + inactive_boxes_cxcyhw = box_xyxy_to_cxcywh(torch.stack([ + track.pos for track in self.inactive_tracks])).cpu().numpy() + + dist_mat = inactive_boxes_cxcyhw[:, :2].reshape(-1, 1, 2) - \ + new_det_boxes_cxcyhw[:, :2].reshape(1, -1, 2) + dist_mat = (dist_mat ** 2).sum(axis=2) + + track_size = inactive_boxes_cxcyhw[:, 2] * inactive_boxes_cxcyhw[:, 3] + item_size = new_det_boxes_cxcyhw[:, 2] * new_det_boxes_cxcyhw[:, 3] + + invalid = ((dist_mat > track_size.reshape(len(track_size), 1)) + \ + (dist_mat > item_size.reshape(1, len(item_size)))) + dist_mat = dist_mat + invalid * 1e18 + + def greedy_assignment(dist): + matched_indices = [] + if dist.shape[1] == 0: + return np.array(matched_indices, np.int32).reshape(-1, 2) + for i in range(dist.shape[0]): + j = dist[i].argmin() + if dist[i][j] < 1e16: + dist[:, j] = 1e18 + dist[i, j] = 0.0 + matched_indices.append([i, j]) + return np.array(matched_indices, np.int32).reshape(-1, 2) + + matched_indices = greedy_assignment(dist_mat) + row_indices, col_indices = matched_indices[:, 0], matched_indices[:, 1] + + else: + for track in self.inactive_tracks: + track_sim = track.hs_embed[-1] + + track_sim_dists = torch.cat([ + F.pairwise_distance(track_sim, 
sim.unsqueeze(0)) + for sim in new_det_hs_embeds]) + + dist_mat.append(track_sim_dists) + + dist_mat = torch.stack(dist_mat) + + dist_mat = dist_mat.cpu().numpy() + row_indices, col_indices = linear_sum_assignment(dist_mat) + + assigned_indices = [] + remove_inactive = [] + for row_ind, col_ind in zip(row_indices, col_indices): + if dist_mat[row_ind, col_ind] <= self.reid_sim_threshold: + track = self.inactive_tracks[row_ind] + + self._logger( + f'REID: track.id={track.id} - ' + f'count_inactive={track.count_inactive} - ' + f'to_inactive_frame={self.frame_index - track.count_inactive}') + + track.count_inactive = 0 + track.pos = new_det_boxes[col_ind] + track.score = new_det_scores[col_ind] + track.hs_embed.append(new_det_hs_embeds[col_ind]) + track.reset_last_pos() + + if new_det_masks is not None: + track.mask = new_det_masks[col_ind] + if new_det_attention_maps is not None: + track.attention_map = new_det_attention_maps[col_ind] + + assigned_indices.append(col_ind) + remove_inactive.append(track) + + self.tracks.append(track) + + self.num_reids += 1 + + for track in remove_inactive: + self.inactive_tracks.remove(track) + + reid_mask = torch.ones(new_det_boxes.size(0)).bool().to(self.device) + + for ind in assigned_indices: + reid_mask[ind] = False + + return reid_mask + + # step for track_query_propagation_strategy == 'consistent_pairing' + def consistent_pairing_step(self, blob): + self._logger(f'FRAME: {self.frame_index + 1}') + img = blob['img'].to(self.device) + orig_size = blob['orig_size'].to(self.device) + + ## detection results + target = None + num_prev_track = len(self.tracks) + if num_prev_track > 0: + target = [self.prev_outs] + + outputs, *_ = self.obj_detector(img, target) + # check_prediction(img, outputs, frame_id=self.frame_index+1) + + results = self.obj_detector_post['bbox'](outputs, orig_size) + result = results[0] + boxes = clip_boxes_to_image(result['boxes'], orig_size[0]) + + current_det_scores = result['scores'] + current_det_boxes = boxes + current_det_keep = torch.logical_and( + current_det_scores > self.consistent_pairing_detection_thresh, + result['labels'] == 0) # label is person + + ## TRACKS + if num_prev_track: + prev_track_query_ids = set([t.query_id for t in self.tracks]) + current_det_query_ids = set(current_det_keep.nonzero().squeeze(-1).cpu().numpy()) + + new_ids = sorted(list(current_det_query_ids - prev_track_query_ids)) + tracked_ids = current_det_query_ids & prev_track_query_ids + inactive_ids = prev_track_query_ids - current_det_query_ids + + for t in self.tracks: + if t.query_id in inactive_ids: + t.count_inactive += 1 + elif t.query_id in tracked_ids: + t.count_inactive = 0 + t.pos = current_det_boxes[t.query_id] + t.score = current_det_scores[t.query_id] + + # remove tracks inactive for long time + self.tracks = [t for t in self.tracks if t.count_inactive <= self.inactive_patience] + + new_track_ids = [] + if len(new_ids) > 0: + new_ids_keep = torch.tensor(new_ids) + new_track_ids = self.add_tracks( + current_det_boxes[new_ids_keep], + current_det_scores[new_ids_keep], + [None] * len(new_ids), + query_ids=new_ids) + else: + new_track_ids = self.add_tracks( + current_det_boxes[current_det_keep], + current_det_scores[current_det_keep], + [None] * len(current_det_boxes[current_det_keep]), + query_ids=list(current_det_keep.nonzero().squeeze(-1).cpu().numpy())) + + active_tracks = [t for t in self.tracks if t.count_inactive == 0] + # NMS to filter duplicates + if self.detection_nms_thresh and len(active_tracks) > 0: + track_boxes = 
torch.stack([t.pos for t in active_tracks]) + track_scores = torch.stack([t.score for t in active_tracks]) + + new_track_mask = torch.tensor([ + True if t.id in new_track_ids else False for t in active_tracks]) + track_scores[~new_track_mask] = np.inf + keep = nms(track_boxes, track_scores, self.detection_nms_thresh) + + remove_track_ids = [track.id for i, track in enumerate(active_tracks) if i not in keep] + if len(remove_track_ids) > 0: + self._logger(f'REMOVE TRACK IDS (detection_nms_thresh={self.detection_nms_thresh}): {remove_track_ids}') + self.tracks = [t for t in self.tracks if t.id not in remove_track_ids] + + ## only record results for active tracks + track_token_propagation_mask = torch.zeros_like(current_det_scores) + for track in self.tracks: + if track.count_inactive > 0: continue + track_token_propagation_mask[track.query_id] = 1 + + if track.id not in self.results: + self.results[track.id] = {} + self.results[track.id][self.frame_index] = {} + self.results[track.id][self.frame_index]['bbox'] = track.pos.cpu().numpy() + self.results[track.id][self.frame_index]['score'] = track.score.cpu().numpy() + self.results[track.id][self.frame_index]['query_idx'] = track.query_id + + self.prev_outs = { + 'track_query_boxes': outputs['pred_boxes'][0].detach(), + 'track_query_hs_embeds': outputs['hs_embed'][0].to(self.device), + 'track_token_propagation_mask': track_token_propagation_mask + } + self.frame_index += 1 + self._prev_blob = blob + + + def step(self, blob): + """This function should be called every timestep to perform tracking with a blob + containing the image information. + """ + if self.obj_detector.track_query_propagation_strategy == 'consistent_pairing': + self.consistent_pairing_step(blob) + return + + ## trackformer tracking on inference + self._logger(f'FRAME: {self.frame_index + 1}') + if self.inactive_tracks: + self._logger(f'INACTIVE TRACK IDS: {[t.id for t in self.inactive_tracks]}') + + # add current position to last_pos list + for track in self.tracks: + track.last_pos.append(track.pos.clone()) + + img = blob['img'].to(self.device) + orig_size = blob['orig_size'].to(self.device) + + target = None + num_prev_track = len(self.tracks + self.inactive_tracks) + if num_prev_track: + track_query_boxes = torch.stack([ + t.pos for t in self.tracks + self.inactive_tracks], dim=0).cpu() + + track_query_boxes = box_xyxy_to_cxcywh(track_query_boxes) + track_query_boxes = track_query_boxes / torch.tensor([ + orig_size[0, 1], orig_size[0, 0], + orig_size[0, 1], orig_size[0, 0]], dtype=torch.float32) + + target = {'track_query_boxes': track_query_boxes} + + target['image_id'] = torch.tensor([1]).to(self.device) + target['track_query_hs_embeds'] = torch.stack([ + t.hs_embed[-1] for t in self.tracks + self.inactive_tracks], dim=0) + + target = {k: v.to(self.device) for k, v in target.items()} + target = [target] + + outputs, *_ = self.obj_detector(img, target) + # check_prediction(img, outputs, frame_id=self.frame_index+1) + + hs_embeds = outputs['hs_embed'][0] + + results = self.obj_detector_post['bbox'](outputs, orig_size) + if "segm" in self.obj_detector_post: + results = self.obj_detector_post['segm']( + results, + outputs, + orig_size, + blob["size"].to(self.device), + return_probs=True) + result = results[0] + + if 'masks' in result: + result['masks'] = result['masks'].squeeze(dim=1) + + boxes = clip_boxes_to_image(result['boxes'], orig_size[0]) + + # TRACKS + if num_prev_track: + track_scores = result['scores'][:-self.num_object_queries] + track_boxes = 
boxes[:-self.num_object_queries] + + if 'masks' in result: + track_masks = result['masks'][:-self.num_object_queries] + if self.generate_attention_maps: + track_attention_maps = self.attention_data['maps'][:-self.num_object_queries] + + track_keep = torch.logical_and( + track_scores > self.track_obj_score_thresh, + result['labels'][:-self.num_object_queries] == 0) + + tracks_to_inactive = [] + tracks_from_inactive = [] + + for i, track in enumerate(self.tracks): + if track_keep[i]: + track.score = track_scores[i] + track.hs_embed.append(hs_embeds[i]) + track.pos = track_boxes[i] + + if 'masks' in result: + track.mask = track_masks[i] + if self.generate_attention_maps: + track.attention_map = track_attention_maps[i] + else: + tracks_to_inactive.append(track) + + track_keep = torch.logical_and( + track_scores > self.reid_score_thresh, + result['labels'][:-self.num_object_queries] == 0) + + # reid queries + for i, track in enumerate(self.inactive_tracks, start=len(self.tracks)): + if track_keep[i]: + track.score = track_scores[i] + track.hs_embed.append(hs_embeds[i]) + track.pos = track_boxes[i] + + if 'masks' in result: + track.mask = track_masks[i] + if self.generate_attention_maps: + track.attention_map = track_attention_maps[i] + + tracks_from_inactive.append(track) + + if tracks_to_inactive: + self._logger( + f'NEW INACTIVE TRACK IDS ' + f'(track_obj_score_thresh={self.track_obj_score_thresh}): ' + f'{[t.id for t in tracks_to_inactive]}') + + self.num_reids += len(tracks_from_inactive) + for track in tracks_from_inactive: + self.inactive_tracks.remove(track) + self.tracks.append(track) + + self.tracks_to_inactive(tracks_to_inactive) + # self.tracks = [ + # track for track in self.tracks + # if track not in tracks_to_inactive] + + if self.track_nms_thresh and self.tracks: + track_boxes = torch.stack([t.pos for t in self.tracks]) + track_scores = torch.stack([t.score for t in self.tracks]) + + keep = nms(track_boxes, track_scores, self.track_nms_thresh) + remove_tracks = [ + track for i, track in enumerate(self.tracks) + if i not in keep] + + if remove_tracks: + self._logger( + f'REMOVE TRACK IDS (track_nms_thresh={self.track_nms_thresh}): ' + f'{[track.id for track in remove_tracks]}') + + # self.tracks_to_inactive(remove_tracks) + self.tracks = [ + track for track in self.tracks + if track not in remove_tracks] + + # NEW DETS + new_det_scores = result['scores'][-self.num_object_queries:] + new_det_boxes = boxes[-self.num_object_queries:] + new_det_hs_embeds = hs_embeds[-self.num_object_queries:] + + if 'masks' in result: + new_det_masks = result['masks'][-self.num_object_queries:] + if self.generate_attention_maps: + new_det_attention_maps = self.attention_data['maps'][-self.num_object_queries:] + + new_det_keep = torch.logical_and( + new_det_scores > self.detection_obj_score_thresh, + result['labels'][-self.num_object_queries:] == 0) ## new detections are only searched for among the last 300 (object) queries + + new_det_boxes = new_det_boxes[new_det_keep] + new_det_scores = new_det_scores[new_det_keep] + new_det_hs_embeds = new_det_hs_embeds[new_det_keep] + + if 'masks' in result: + new_det_masks = new_det_masks[new_det_keep] + if self.generate_attention_maps: + new_det_attention_maps = new_det_attention_maps[new_det_keep] + + # public detection ## when self.public_detections=False, this makes no changes + public_detections_mask = self.public_detections_mask( + new_det_boxes, blob['dets'][0]) + + new_det_boxes = new_det_boxes[public_detections_mask] + new_det_scores = new_det_scores[public_detections_mask] + new_det_hs_embeds = 
new_det_hs_embeds[public_detections_mask] + if 'masks' in result: + new_det_masks = new_det_masks[public_detections_mask] + if self.generate_attention_maps: + new_det_attention_maps = new_det_attention_maps[public_detections_mask] + + # reid: try to re-identify inactive tracks among the new detections + reid_mask = self.reid( + new_det_boxes, + new_det_scores, + new_det_hs_embeds, + new_det_masks if 'masks' in result else None, + new_det_attention_maps if self.generate_attention_maps else None) + + new_det_boxes = new_det_boxes[reid_mask] + new_det_scores = new_det_scores[reid_mask] + new_det_hs_embeds = new_det_hs_embeds[reid_mask] + if 'masks' in result: + new_det_masks = new_det_masks[reid_mask] + if self.generate_attention_maps: + new_det_attention_maps = new_det_attention_maps[reid_mask] + + # finally, add the remaining detections as new tracks + new_track_ids = self.add_tracks( + new_det_boxes, + new_det_scores, + new_det_hs_embeds, + new_det_masks if 'masks' in result else None, + new_det_attention_maps if self.generate_attention_maps else None) + + # NMS + if self.detection_nms_thresh and self.tracks: + track_boxes = torch.stack([t.pos for t in self.tracks]) + track_scores = torch.stack([t.score for t in self.tracks]) + + new_track_mask = torch.tensor([ + True if t.id in new_track_ids + else False + for t in self.tracks]) + track_scores[~new_track_mask] = np.inf + + keep = nms(track_boxes, track_scores, self.detection_nms_thresh) + remove_tracks = [track for i, track in enumerate(self.tracks) if i not in keep] + + if remove_tracks: + self._logger( + f'REMOVE TRACK IDS (detection_nms_thresh={self.detection_nms_thresh}): ' + f'{[track.id for track in remove_tracks]}') + + self.tracks = [track for track in self.tracks if track not in remove_tracks] + + #################### + # Generate Results # + #################### + + if 'masks' in result and self.tracks: + track_mask_probs = torch.stack([track.mask for track in self.tracks]) + index_map = torch.arange(track_mask_probs.size(0))[:, None, None] + index_map = index_map.expand_as(track_mask_probs) + + track_masks = torch.logical_and( + # remove background + track_mask_probs > 0.5, + # remove overlap by largest probability + index_map == track_mask_probs.argmax(dim=0) + ) + for i, track in enumerate(self.tracks): + track.mask = track_masks[i] + + for track in self.tracks: + if track.id not in self.results: + self.results[track.id] = {} + + self.results[track.id][self.frame_index] = {} + self.results[track.id][self.frame_index]['bbox'] = track.pos.cpu().numpy() + self.results[track.id][self.frame_index]['score'] = track.score.cpu().numpy() + + if track.mask is not None: + self.results[track.id][self.frame_index]['mask'] = track.mask.cpu().numpy() + if track.attention_map is not None: + self.results[track.id][self.frame_index]['attention_map'] = \ + track.attention_map.cpu().numpy() + + for t in self.inactive_tracks: + t.count_inactive += 1 + + self.frame_index += 1 + self._prev_blob = blob + + if self.reid_sim_only: + self.tracks_to_inactive(self.tracks) + + def get_results(self): + """Return current tracking results.""" + return self.results + + +class Track(object): + """This class contains all necessary data for every individual track.""" + + def __init__(self, pos, score, track_id, hs_embed, + mask=None, attention_map=None): + self.id = track_id + self.pos = pos + self.last_pos = deque([pos.clone()]) + self.score = score + self.ims = deque([]) + self.count_inactive = 0 + self.gt_id = None + self.hs_embed = [hs_embed] + self.mask = mask + self.attention_map = attention_map + + # consistent_pairing mode + self.query_id = 
None + + def has_positive_area(self) -> bool: + """Checks if the current position of the track has + a valid, .i.e., positive area, bounding box.""" + return self.pos[2] > self.pos[0] and self.pos[3] > self.pos[1] + + def reset_last_pos(self) -> None: + """Reset last_pos to the current position of the track.""" + self.last_pos.clear() + self.last_pos.append(self.pos.clone()) diff --git a/src/trackformer/models/transformer.py b/src/trackformer/models/transformer.py new file mode 100644 index 0000000..52778d7 --- /dev/null +++ b/src/trackformer/models/transformer.py @@ -0,0 +1,337 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False, + track_attention=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, encoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec, + track_attention=track_attention) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed, tgt=None, prev_frame=None): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + mask = mask.flatten(1) + + if tgt is None: + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + + memory_prev_frame = None + if prev_frame is not None: + src_prev_frame = prev_frame['src'].flatten(2).permute(2, 0, 1) + pos_embed_prev_frame = prev_frame['pos'].flatten(2).permute(2, 0, 1) + mask_prev_frame = prev_frame['mask'].flatten(1) + + memory_prev_frame = self.encoder( + src_prev_frame, src_key_padding_mask=mask_prev_frame, pos=pos_embed_prev_frame) + + prev_frame['memory'] = memory_prev_frame + prev_frame['memory_key_padding_mask'] = mask_prev_frame + prev_frame['pos'] = pos_embed_prev_frame + + hs, hs_without_norm = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed, + prev_frame=prev_frame) + + return (hs.transpose(1, 2), + hs_without_norm.transpose(1, 2), + memory.permute(1, 2, 0).view(bs, c, h, w)) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: 
Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, encoder_layer, num_layers, + norm=None, return_intermediate=False, track_attention=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + self.track_attention = track_attention + if self.track_attention: + self.layers_track_attention = _get_clones(encoder_layer, num_layers) + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + prev_frame: Optional[dict] = None): + output = tgt + + intermediate = [] + + if self.track_attention: + track_query_pos = query_pos[:-100].clone() + query_pos[:-100] = 0.0 + + for i, layer in enumerate(self.layers): + if self.track_attention: + track_output = output[:-100].clone() + + track_output = self.layers_track_attention[i]( + track_output, + src_mask=tgt_mask, + src_key_padding_mask=tgt_key_padding_mask, + pos=track_query_pos) + + output = torch.cat([track_output, output[-100:]]) + + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + output = torch.stack(intermediate) + + if self.norm is not None: + return self.norm(output), output + return output, output + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = 
self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = 
None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + track_attention=args.track_attention + ) diff --git a/src/trackformer/util/__init__.py b/src/trackformer/util/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/src/trackformer/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/src/trackformer/util/box_ops.py b/src/trackformer/util/box_ops.py new file mode 100644 index 0000000..e201ef4 --- /dev/null +++ b/src/trackformer/util/box_ops.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + +def box_cxcywh_to_xywh(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), w, h] + return torch.stack(b, dim=-1) + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + +def box_xyxy_to_xywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [x0, y0, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + 
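+ Boxes are derived from the minimum and maximum x/y coordinates of each mask's non-zero pixels; an empty input yields a [0, 4] tensor.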
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/src/trackformer/util/misc.py b/src/trackformer/util/misc.py new file mode 100644 index 0000000..8c1c241 --- /dev/null +++ b/src/trackformer/util/misc.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import datetime +import os +import pickle +import subprocess +import time +from argparse import Namespace +from collections import defaultdict, deque +from typing import List, Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor +from visdom import Visdom + +# if float(torchvision.__version__[:3]) < 0.7: +# from torchvision.ops import _new_empty_tensor +# from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
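+ Only the running count and total are all-reduced across processes.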
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, print_freq, delimiter="\t", vis=None, debug=False): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.vis = vis + self.print_freq = print_freq + self.debug = debug + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {meter}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, epoch=None, header=None): + i = 0 + if header is None: + header = 'Epoch: [{}]'.format(epoch) + + world_len_iterable = get_world_size() * len(iterable) + + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(world_len_iterable))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data_time: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % self.print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i * get_world_size(), world_len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i * get_world_size(), world_len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + + if self.vis is not None: + y_data = [self.meters[legend_name].median + for legend_name in self.vis.viz_opts['legend'] + if legend_name in self.meters] + y_data.append(iter_time.median) + + self.vis.plot(y_data, i * get_world_size() + (epoch - 1) * world_len_iterable) + + # DEBUG + # if i != 0 and i % self.print_freq == 0: + if self.debug and i % self.print_freq == 0: + break + + i += 1 + end = time.time() + + # if self.vis is not None: + # self.vis.reset() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, 
total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + if isinstance(batch[0][0], list): # clip mode + assert len(batch) == 1 + batch = list(batch[0]) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + else: + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + +def frcnn_collate_fn(batch): + return tuple(zip(*batch)) + +def sttran_collate_fn(batch): + assert len(batch) == 1 # per video + image_batch = torch.stack(batch[0][0], dim=0) + target_list = batch[0][1] + return image_batch, target_list + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, _, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor] = None): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + if not is_master: + def line(*args, **kwargs): + pass + def images(*args, **kwargs): + pass + Visdom.line = line + Visdom.images = images + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not 
is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print(f'| distributed init (rank {args.rank}): {args.dist_url}', flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
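+    Older torchvision releases need the manual empty-batch handling below; newer ones simply forward the call to torchvision.ops.misc.interpolate.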
+ """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +class DistributedWeightedSampler(torch.utils.data.DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, replacement=True): + super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank, shuffle) + + assert replacement + + self.replacement = replacement + + def __iter__(self): + iter_indices = super(DistributedWeightedSampler, self).__iter__() + if hasattr(self.dataset, 'sample_weight'): + indices = list(iter_indices) + + weights = torch.tensor([self.dataset.sample_weight(idx) for idx in indices]) + + g = torch.Generator() + g.manual_seed(self.epoch) + + weight_indices = torch.multinomial( + weights, self.num_samples, self.replacement, generator=g) + indices = torch.tensor(indices)[weight_indices] + + iter_indices = iter(indices.tolist()) + return iter_indices + + def __len__(self): + return self.num_samples + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. 
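+        num_boxes: Normalization factor; the summed focal loss is divided by this value.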
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +def nested_dict_to_namespace(dictionary): + namespace = dictionary + if isinstance(dictionary, dict): + namespace = Namespace(**dictionary) + for key, value in dictionary.items(): + setattr(namespace, key, nested_dict_to_namespace(value)) + return namespace diff --git a/src/trackformer/util/plot_utils.py b/src/trackformer/util/plot_utils.py new file mode 100644 index 0000000..effb90c --- /dev/null +++ b/src/trackformer/util/plot_utils.py @@ -0,0 +1,292 @@ +""" +Plotting utilities to visualize training logs. +""" +import os +from pathlib import Path, PurePath +import random + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +import matplotlib +import matplotlib.patches as patches +import matplotlib.patheffects as PathEffects +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from . import box_ops +from .misc import NestedTensor, nested_tensor_from_tensor_list + +MOT_obj_label_names = ['person', '_BG_'] +# # Action Genome classes +# PredicateClasses = ['looking_at', 'not_looking_at', 'unsure', 'above', 'beneath', 'in_front_of', 'behind', 'on_the_side_of', 'in', 'carrying', 'covered_by', 'drinking_from', 'eating', 'have_it_on_the_back', 'holding', 'leaning_on', 'lying_on', 'not_contacting', 'other_relationship', 'sitting_on', 'standing_on', 'touching', 'twisting', 'wearing', 'wiping', 'writing_on'] +# ObjClasses = ['person', 'bag', 'bed', 'blanket', 'book', 'box', 'broom', 'chair', 'closet/cabinet', 'clothes', 'cup/glass/bottle', 'dish', 'door', 'doorknob', 'doorway', 'floor', 'food', 'groceries', 'laptop', 'light', 'medicine', 'mirror', 'paper/notebook', 'phone/camera', 'picture', 'pillow', 'refrigerator', 'sandwich', 'shelf', 'shoe', 'sofa/couch', 'table', 'television', 'towel', 'vacuum', 'window'] +# VidHOI classes +PredicateClasses = ['lean_on', 'watch', 'above', 'next_to', 'behind', 'away', 'towards', 'in_front_of', 'hit', 'hold', 'wave', 'pat', 'carry', 'point_to', 'touch', 'play(instrument)', 'release', 'ride', 'grab', 'lift', 'use', 'press', 'inside', 'caress', 'pull', 'get_on', 'cut', 'hug', 'bite', 'open', 'close', 'throw', 'kick', 'drive', 'get_off', 'push', 'wave_hand_to', 'feed', 'chase', 'kiss', 'speak_to', 'beneath', 'smell', 'clean', 'lick', 'squeeze', 'shake_hand_with', 'knock', 'hold_hand_of', 'shout_at'] +ObjClasses = ['person', 'car', 'guitar', 'chair', 'handbag', 'toy', 'baby_seat', 'cat', 'bottle', 'backpack', 'motorcycle', 'ball/sports_ball', 'laptop', 'table', 'surfboard', 'camera', 'sofa', 'screen/monitor', 'bicycle', 'vegetables', 'dog', 'fruits', 'cake', 'cellphone', 'cup', 'bench', 'snowboard', 'skateboard', 'bread', 'bus/truck', 'ski', 'suitcase', 'stool', 'bat', 'elephant', 'fish', 'baby_walker', 'dish', 'watercraft', 'scooter', 'pig', 'refrigerator', 'horse', 'crab', 'bird', 'piano', 'cattle/cow', 'lion', 'chicken', 'camel', 'electric_fan', 'toilet', 'sheep/goat', 'rabbit', 'train', 'penguin', 'hamster/rat', 'snake', 'frisbee', 'aircraft', 'oven', 'racket', 'faucet', 'antelope', 'duck', 'stop_sign', 'sink', 'kangaroo', 'stingray', 'turtle', 'tiger', 'crocodile', 'bear', 'microwave', 'traffic_light', 
'panda', 'leopard', 'squirrel'] + +# COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], +# [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],] +cmap = matplotlib.cm.get_cmap('hsv') +COLORS = [cmap(idx/300) for idx in random.sample(range(0, 300), 300)] + +def check_annotation(samples, annotations, mode='train', idx=0): + img_tensors, img_masks = samples.decompose() + h, w = (img_masks[idx].float() < 1).nonzero(as_tuple=False).max(0)[0].cpu() + 1 + + img_tensor = img_tensors[idx,:,:h,:w].cpu().permute(1,2,0) + img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min()) + + res = annotations[idx] + org_h, org_w = res['orig_size'].cpu().float() + boxes = res['boxes'].cpu() + if mode == 'train': + boxes = box_ops.box_cxcywh_to_xyxy(boxes) + boxes = boxes * torch.tensor([w, h, w, h]).unsqueeze(0) + else: + boxes = boxes * torch.tensor([w/org_w, h/org_h, w/org_w, h/org_h]).unsqueeze(0) + obj_show_names = [] + for ind, (x, tid) in enumerate(zip(res['labels'], res['track_ids'])): + obj_show_names.append(f"{ObjClasses[x]}_{tid}") + + # draw images + plt.imshow(img_tensor) + for ind, bbox in enumerate(boxes): + x1, y1, x2, y2 = bbox + rect = patches.Rectangle((x1,y1), x2-x1+1, y2-y1+1, linewidth=1, edgecolor='r', facecolor='none') + plt.gca().add_patch(rect) + txt = plt.text(x1, y1, obj_show_names[ind], color='black') + txt.set_path_effects([PathEffects.withStroke(linewidth=5, foreground='w')]) + plt.gca().yaxis.set_label_position("right") + plt.title(f"image_id={annotations[idx]['image_id'].item()}") + + if 'relation_map' in res: + rels = res['relation_map'].nonzero(as_tuple=False).cpu().numpy() + rel_strs = '' + for i, rel in enumerate(rels): + rel_strs += f"{obj_show_names[rel[0]]} --{PredicateClasses[rel[2]]}--> {obj_show_names[rel[1]]}\n" + + print(f"image_id={annotations[idx]['image_id'].item()}:\n", rel_strs, '\n') + plt.xlabel(rel_strs, rotation=0, fontsize=6) + plt.text(5, 250, rel_strs, fontsize=6, color='red') + + plt.axis('off') + plt.tight_layout() + plt.show() + +def check_prediction(samples, results, threshold=0.9, targets=None, idx=0, frame_id=None, top_pred_rel_pairs=None, save_fig_dir=None): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + pil_imgs, masks = samples.decompose() + pil_img, mask= pil_imgs[idx], masks[idx] + + pil_img = ((pil_img - pil_img.min()) / (pil_img.max() - pil_img.min())).permute(1,2,0).cpu().numpy() + h, w = (~mask).float().nonzero(as_tuple=False).max(0)[0] + 1 + pil_img = pil_img[:h, :w] + + if isinstance(results, list): + boxes = box_ops.box_cxcywh_to_xyxy(results[idx]['pred_boxes'].detach().cpu()[0]) * torch.tensor([w,h,w,h]) + box_scores, box_labels = results[idx]['pred_logits'].softmax(-1)[..., :-1].detach().cpu()[0].max(-1) + else: + boxes = box_ops.box_cxcywh_to_xyxy(results['pred_boxes'][idx].detach().cpu()) * torch.tensor([w,h,w,h]) + box_scores, box_labels = results['pred_logits'][idx].softmax(-1)[..., :-1].detach().cpu().max(-1) + + # print relation predictions + prop_obj_ids, top_rel_str = [], '' + if 'pred_rel_pairs' in results: + img_rel_proposal_pairs = results['pred_rel_pairs'][idx] + img_rel_prop_scores = results['pred_relation_exists'][idx].sigmoid() + img_rel_predicate_scores = results['pred_relations'][idx].sigmoid() + if top_pred_rel_pairs is not None: + print(f'Top predicted relations for image_id={targets[idx]["image_id"].item()}:') + show_rel_triplets = top_pred_rel_pairs[idx][:10] # top 10 + prop_obj_ids = 
show_rel_triplets[:, :2].flatten() + for sub, obj, predicate in show_rel_triplets: + prop_idx = img_rel_proposal_pairs.tolist().index([sub, obj]) + predicate_score = img_rel_predicate_scores[prop_idx, predicate] + top_rel_str += f"{ObjClasses[box_labels[sub]]}_{sub}({box_scores[sub]:.2f}) -- {PredicateClasses[predicate]}({predicate_score:.2f}) --> {ObjClasses[box_labels[obj]]}_{obj}({box_scores[obj]:.2f})\n" + print(top_rel_str) + else: + print(f'Top predicted pairs for image_id={targets[idx]["image_id"].item()}:') + prop_obj_ids = img_rel_proposal_pairs.view(-1).tolist() + for p, ps, pred_scores in zip(img_rel_proposal_pairs, img_rel_prop_scores, img_rel_predicate_scores): + top_predicates_scores, top_predicates = pred_scores.sort(descending=True) + top_predicates_str = ', '.join([f"{PredicateClasses[top_predicates[k]]} ({top_predicates_scores[k]:.2f})" for k in range(3)]) + print(f'\033[94m{ObjClasses[box_labels[p[0]]]}_{p[0]}-{ObjClasses[box_labels[p[1]]]}_{p[1]} ({ps:.2f}):\t \033[92m{top_predicates_str}\033[0m') + + ######## plt detected boxes ########## + plt.imshow(pil_img, alpha=0.5) + for id, (sc, l, (xmin, ymin, xmax, ymax), c) in enumerate(zip(box_scores, box_labels, boxes, COLORS)): + # if id in prop_obj_ids or sc > threshold: + # if sc > threshold: + if id in prop_obj_ids: + plt.gca().add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c)) + # text = f'{str(id)}_l={l}({sc:0.2f})' + text = f'{ObjClasses[l]}_{str(id)}' + plt.text(xmin, ymin, text, fontsize=18, bbox=dict(facecolor=c, alpha=0.6)) + + if len(top_rel_str) > 0: plt.text(5, 250, top_rel_str, fontsize=6, color='red') + if targets is not None: plt.title(f"image_id={targets[idx]['image_id'].item()}") + if frame_id is not None: plt.title(f"frame_id={frame_id}") + plt.axis('off') + plt.tight_layout() + + if save_fig_dir is not None: + if not os.path.isdir(save_fig_dir): + os.system(f'mkdir -p {save_fig_dir}') + plt.savefig(f"{save_fig_dir}/image_id={targets[idx]['image_id'].item()}.png") + plt.show() + +def check_sttran_prediction(images, rel_outs, box_preds, targets, top_pred_rel_pairs, idx=0, save_fig_dir=None): + pil_img = images[idx].permute(1,2,0).cpu().numpy() + + # box predictions + real_image_id = targets[idx]['image_id'].item() + boxes, box_scores, box_labels = box_preds[real_image_id]['boxes'], box_preds[real_image_id]['scores'], box_preds[real_image_id]['labels'] + + # relation predictions + prop_obj_ids, top_rel_str = [], '' + if top_pred_rel_pairs is not None: + img_rel_proposal_pairs = rel_outs['pred_rel_pairs'][idx] + img_rel_predicate_scores = rel_outs['pred_relations'][idx].sigmoid() + + print(f'Top predicted relations for image_id={targets[idx]["image_id"].item()}:') + show_rel_triplets = top_pred_rel_pairs[idx][:10] + prop_obj_ids = show_rel_triplets[:, :2].flatten() + for sub, obj, predicate in show_rel_triplets: + prop_idx = img_rel_proposal_pairs.tolist().index([sub, obj]) + predicate_score = img_rel_predicate_scores[prop_idx, predicate] + top_rel_str += f"{ObjClasses[box_labels[sub]]}_{sub}({box_scores[sub]:.2f}) -- {PredicateClasses[predicate]}({predicate_score:.2f}) --> {ObjClasses[box_labels[obj]]}_{obj}({box_scores[obj]:.2f})\n" + + print(top_rel_str) + + ######## plt detected boxes ########## + plt.imshow(pil_img, alpha=0.5) + for id, (sc, l, (xmin, ymin, xmax, ymax)) in enumerate(zip(box_scores, box_labels, boxes)): + c = COLORS[id+idx*10] # randomize color + if id in prop_obj_ids: + plt.gca().add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, 
fill=False, color=c)) + # text = f'{ObjClasses[l]}_{str(id)}' + text = f'{ObjClasses[l]}' + plt.text(xmin, ymin, text, fontsize=18, bbox=dict(facecolor=c, alpha=0.6)) + + # show image + # if len(top_rel_str) > 0: plt.text(5, 70, top_rel_str, fontsize=6, color='red') + # if targets is not None: plt.title(f"image_id={targets[idx]['image_id'].item()}") + plt.axis('off') + plt.tight_layout() + if save_fig_dir is not None: + if not os.path.isdir(save_fig_dir): + os.system(f'mkdir -p {save_fig_dir}') + plt.savefig(f"{save_fig_dir}/notext_image_id={targets[idx]['image_id'].item()}.png") + plt.show() + +def fig_to_numpy(fig): + w, h = fig.get_size_inches() * fig.dpi + w = int(w.item()) + h = int(h.item()) + canvas = FigureCanvas(fig) + canvas.draw() + numpy_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3) + return np.copy(numpy_image) + + +def get_vis_win_names(vis_dict): + vis_win_names = { + outer_k: { + inner_k: inner_v.win + for inner_k, inner_v in outer_v.items() + } + for outer_k, outer_v in vis_dict.items() + } + return vis_win_names + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. 
+ + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # verify valid dir(s) and that every item in list is Path object + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if dir.exists(): + continue + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/src/trackformer/util/track_utils.py b/src/trackformer/util/track_utils.py new file mode 100644 index 0000000..9fddc08 --- /dev/null +++ b/src/trackformer/util/track_utils.py @@ -0,0 +1,414 @@ +######################################### +# Still ugly file with helper functions # +######################################### + +import os +from collections import defaultdict +from os import path as osp + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import motmetrics as mm +import numpy as np +import torch +import torchvision.transforms.functional as F +import tqdm +from cycler import cycler as cy +from matplotlib import colors +from scipy.interpolate import interp1d + +# matplotlib.use('Agg') + + +# From frcnn/utils/bbox.py +def bbox_overlaps(boxes, query_boxes): 
+ """ + Parameters + ---------- + boxes: (N, 4) ndarray or tensor or variable + query_boxes: (K, 4) ndarray or tensor or variable + Returns + ------- + overlaps: (N, K) overlap between boxes and query_boxes + """ + if isinstance(boxes, np.ndarray): + boxes = torch.from_numpy(boxes) + query_boxes = torch.from_numpy(query_boxes) + out_fn = lambda x: x.numpy() # If input is ndarray, turn the overlaps back to ndarray when return + else: + out_fn = lambda x: x + + box_areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1) + query_areas = (query_boxes[:, 2] - query_boxes[:, 0] + 1) * (query_boxes[:, 3] - query_boxes[:, 1] + 1) + + iw = (torch.min(boxes[:, 2:3], query_boxes[:, 2:3].t()) - torch.max(boxes[:, 0:1], + query_boxes[:, 0:1].t()) + 1).clamp(min=0) + ih = (torch.min(boxes[:, 3:4], query_boxes[:, 3:4].t()) - torch.max(boxes[:, 1:2], + query_boxes[:, 1:2].t()) + 1).clamp(min=0) + ua = box_areas.view(-1, 1) + query_areas.view(1, -1) - iw * ih + overlaps = iw * ih / ua + return out_fn(overlaps) + + +def rand_cmap(nlabels, type='bright', first_color_black=True, last_color_black=False, verbose=False): + """ + Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks + :param nlabels: Number of labels (size of colormap) + :param type: 'bright' for strong colors, 'soft' for pastel colors + :param first_color_black: Option to use first color as black, True or False + :param last_color_black: Option to use last color as black, True or False + :param verbose: Prints the number of labels and shows the colormap. True or False + :return: colormap for matplotlib + """ + import colorsys + + import numpy as np + from matplotlib.colors import LinearSegmentedColormap + + + if type not in ('bright', 'soft'): + print ('Please choose "bright" or "soft" for type') + return + + if verbose: + print('Number of labels: ' + str(nlabels)) + + # Generate color map for bright colors, based on hsv + if type == 'bright': + randHSVcolors = [(np.random.uniform(low=0.0, high=1), + np.random.uniform(low=0.2, high=1), + np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] + + # Convert HSV list to RGB + randRGBcolors = [] + for HSVcolor in randHSVcolors: + randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Generate soft pastel colors, by limiting the RGB spectrum + if type == 'soft': + low = 0.6 + high = 0.95 + randRGBcolors = [(np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high)) for i in range(nlabels)] + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Display colorbar + if verbose: + from matplotlib import colorbar, colors + from matplotlib import pyplot as plt + fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) + + bounds = np.linspace(0, nlabels, nlabels + 1) + norm = colors.BoundaryNorm(bounds, nlabels) + + colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, + boundaries=bounds, format='%1i', orientation=u'horizontal') + + return random_colormap + + +def plot_sequence(tracks, data_loader, output_dir, write_images, generate_attention_maps): + """Plots a 
whole sequence + + Args: + tracks (dict): The dictionary containing the track dictionaries in the form tracks[track_id][frame] = bb + db (torch.utils.data.Dataset): The dataset with the images belonging to the tracks (e.g. MOT_Sequence object) + output_dir (String): Directory where to save the resulting images + """ + if not osp.exists(output_dir): + os.makedirs(output_dir) + + # infinite color loop + # cyl = cy('ec', COLORS) + # loop_cy_iter = cyl() + # styles = defaultdict(lambda: next(loop_cy_iter)) + + # cmap = plt.cm.get_cmap('hsv', ) + cmap = rand_cmap(len(tracks)*10, type='bright', first_color_black=False, last_color_black=False) + + # if generate_attention_maps: + # attention_maps_per_track = { + # track_id: (np.concatenate([t['attention_map'] for t in track.values()]) + # if len(track) > 1 + # else list(track.values())[0]['attention_map']) + # for track_id, track in tracks.items()} + # attention_map_thresholds = { + # track_id: np.histogram(maps, bins=2)[1][1] + # for track_id, maps in attention_maps_per_track.items()} + + # _, attention_maps_bin_edges = np.histogram(all_attention_maps, bins=2) + print('Saving images...') + for frame_id, frame_data in enumerate(tqdm.tqdm(data_loader)): + img_path = frame_data['img_path'][0] + img = cv2.imread(img_path)[:, :, (2, 1, 0)] + height, width, _ = img.shape + + fig = plt.figure() + fig.set_size_inches(width / 96, height / 96) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + fig.add_axes(ax) + ax.imshow(img) + + if generate_attention_maps: + attention_map_img = np.zeros((height, width, 4)) + + for track_id, track_data in tracks.items(): + if frame_id in track_data.keys(): + bbox = track_data[frame_id]['bbox'] + + if 'mask' in track_data[frame_id]: + mask = track_data[frame_id]['mask'] + mask = np.ma.masked_where(mask == 0.0, mask) + + ax.imshow(mask, alpha=0.5, cmap=colors.ListedColormap([cmap(track_id)])) + + annotate_color = 'white' + else: + ax.add_patch( + plt.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + fill=False, + linewidth=2.0, + color=cmap(track_id), + )) + + annotate_color = cmap(track_id) + + if write_images == 'debug': + ax.annotate( + f"tid={track_id} ({track_data[frame_id]['score']:.2f})", + (bbox[0], bbox[1]), + color=annotate_color, weight='bold', fontsize=12, ha='center', va='center') + + if 'attention_map' in track_data[frame_id]: + attention_map = track_data[frame_id]['attention_map'] + attention_map = cv2.resize(attention_map, (width, height)) + + # attention_map_img = np.ones((height, width, 4)) * cmap(track_id) + # # max value will be at 0.75 transparency + # attention_map_img[:, :, 3] = attention_map * 0.75 / attention_map.max() + + # _, bin_edges = np.histogram(attention_map, bins=2) + # attention_map_img[:, :][attention_map < bin_edges[1]] = 0.0 + + # attention_map_img += attention_map_img + + # _, bin_edges = np.histogram(attention_map, bins=2) + + norm_attention_map = attention_map / attention_map.max() + + high_att_mask = norm_attention_map > 0.25 # bin_edges[1] + attention_map_img[:, :][high_att_mask] = cmap(track_id) + attention_map_img[:, :, 3][high_att_mask] = norm_attention_map[high_att_mask] * 0.5 + + # attention_map_img[:, :] += (np.tile(attention_map[..., np.newaxis], (1,1,4)) / attention_map.max()) * cmap(track_id) + # attention_map_img[:, :, 3] = 0.75 + + if generate_attention_maps: + ax.imshow(attention_map_img, vmin=0.0, vmax=1.0) + + plt.axis('off') + # plt.tight_layout() + plt.draw() + plt.savefig(osp.join(output_dir, osp.basename(img_path)), dpi=96) + 
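# close the figure after saving each frame so figures do not accumulate across the sequence +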
plt.close() + + +def interpolate_tracks(tracks): + for i, track in tracks.items(): + frames = [] + x0 = [] + y0 = [] + x1 = [] + y1 = [] + + for f, data in track.items(): + frames.append(f) + x0.append(data['bbox'][0]) + y0.append(data['bbox'][1]) + x1.append(data['bbox'][2]) + y1.append(data['bbox'][3]) + + if len(frames) > 1: + x0_inter = interp1d(frames, x0) + y0_inter = interp1d(frames, y0) + x1_inter = interp1d(frames, x1) + y1_inter = interp1d(frames, y1) + + for f in range(min(frames), max(frames) + 1): + bbox = np.array([ + x0_inter(f), + y0_inter(f), + x1_inter(f), + y1_inter(f)]) + tracks[i][f]['bbox'] = bbox + else: + tracks[i][frames[0]]['bbox'] = np.array([ + x0[0], y0[0], x1[0], y1[0]]) + + return tracks + + +def bbox_transform_inv(boxes, deltas): + # Input should be both tensor or both Variable and on the same device + if len(boxes) == 0: + return deltas.detach() * 0 + + widths = boxes[:, 2] - boxes[:, 0] + 1.0 + heights = boxes[:, 3] - boxes[:, 1] + 1.0 + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + dx = deltas[:, 0::4] + dy = deltas[:, 1::4] + dw = deltas[:, 2::4] + dh = deltas[:, 3::4] + + pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1) + pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1) + pred_w = torch.exp(dw) * widths.unsqueeze(1) + pred_h = torch.exp(dh) * heights.unsqueeze(1) + + pred_boxes = torch.cat( + [_.unsqueeze(2) for _ in [pred_ctr_x - 0.5 * pred_w, + pred_ctr_y - 0.5 * pred_h, + pred_ctr_x + 0.5 * pred_w, + pred_ctr_y + 0.5 * pred_h]], 2).view(len(boxes), -1) + return pred_boxes + + +def clip_boxes(boxes, im_shape): + """ + Clip boxes to image boundaries. + boxes must be tensor or Variable, im_shape can be anything but Variable + """ + if not hasattr(boxes, 'data'): + boxes_ = boxes.numpy() + + boxes = boxes.view(boxes.size(0), -1, 4) + boxes = torch.stack([ + boxes[:, :, 0].clamp(0, im_shape[1] - 1), + boxes[:, :, 1].clamp(0, im_shape[0] - 1), + boxes[:, :, 2].clamp(0, im_shape[1] - 1), + boxes[:, :, 3].clamp(0, im_shape[0] - 1) + ], 2).view(boxes.size(0), -1) + + return boxes + + +def get_center(pos): + x1 = pos[0, 0] + y1 = pos[0, 1] + x2 = pos[0, 2] + y2 = pos[0, 3] + return torch.Tensor([(x2 + x1) / 2, (y2 + y1) / 2]).cuda() + + +def get_width(pos): + return pos[0, 2] - pos[0, 0] + + +def get_height(pos): + return pos[0, 3] - pos[0, 1] + + +def make_pos(cx, cy, width, height): + return torch.Tensor([[ + cx - width / 2, + cy - height / 2, + cx + width / 2, + cy + height / 2 + ]]).cuda() + + +def warp_pos(pos, warp_matrix): + p1 = torch.Tensor([pos[0, 0], pos[0, 1], 1]).view(3, 1) + p2 = torch.Tensor([pos[0, 2], pos[0, 3], 1]).view(3, 1) + p1_n = torch.mm(warp_matrix, p1).view(1, 2) + p2_n = torch.mm(warp_matrix, p2).view(1, 2) + return torch.cat((p1_n, p2_n), 1).view(1, -1).cuda() + + +def get_mot_accum(results, seq_loader): + mot_accum = mm.MOTAccumulator(auto_id=True) + + for frame_id, frame_data in enumerate(seq_loader): + gt = frame_data['gt'] + gt_ids = [] + if gt: + gt_boxes = [] + for gt_id, gt_box in gt.items(): + gt_ids.append(gt_id) + gt_boxes.append(gt_box[0]) + + gt_boxes = np.stack(gt_boxes, axis=0) + # x1, y1, x2, y2 --> x1, y1, width, height + gt_boxes = np.stack( + (gt_boxes[:, 0], + gt_boxes[:, 1], + gt_boxes[:, 2] - gt_boxes[:, 0], + gt_boxes[:, 3] - gt_boxes[:, 1]), axis=1) + else: + gt_boxes = np.array([]) + + track_ids = [] + track_boxes = [] + for track_id, track_data in results.items(): + if frame_id in track_data: + track_ids.append(track_id) + # frames = x1, y1, x2, y2, score +
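# boxes are stored as x1, y1, x2, y2 and converted to x1, y1, width, height below for motmetrics +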
track_boxes.append(track_data[frame_id]['bbox']) + + if track_ids: + track_boxes = np.stack(track_boxes, axis=0) + # x1, y1, x2, y2 --> x1, y1, width, height + track_boxes = np.stack( + (track_boxes[:, 0], + track_boxes[:, 1], + track_boxes[:, 2] - track_boxes[:, 0], + track_boxes[:, 3] - track_boxes[:, 1]), axis=1) + else: + track_boxes = np.array([]) + + distance = mm.distances.iou_matrix(gt_boxes, track_boxes, max_iou=0.5) + + mot_accum.update( + gt_ids, + track_ids, + distance) + + return mot_accum + + +def evaluate_mot_accums(accums, names, generate_overall=True): + mh = mm.metrics.create() + summary = mh.compute_many( + accums, + metrics=mm.metrics.motchallenge_metrics, + names=names, + generate_overall=generate_overall,) + + str_summary = mm.io.render_summary( + summary, + formatters=mh.formatters, + namemap=mm.io.motchallenge_metric_names,) + return summary, str_summary diff --git a/src/trackformer/vis.py b/src/trackformer/vis.py new file mode 100644 index 0000000..93a9385 --- /dev/null +++ b/src/trackformer/vis.py @@ -0,0 +1,357 @@ +import copy +import logging + +import matplotlib.patches as mpatches +import numpy as np +import torch +import torchvision.transforms as T +from matplotlib import colors +from matplotlib import pyplot as plt +from visdom import Visdom + +from .util.plot_utils import fig_to_numpy + +logging.getLogger('visdom').setLevel(logging.CRITICAL) + + +class BaseVis(object): + + def __init__(self, viz_opts, update_mode='append', env=None, win=None, + resume=False, port=8097, server='http://localhost'): + self.viz_opts = viz_opts + self.update_mode = update_mode + self.win = win + if env is None: + env = 'main' + self.viz = Visdom(env=env, port=port, server=server) + # if resume first plot should not update with replace + self.removed = not resume + + def win_exists(self): + return self.viz.win_exists(self.win) + + def close(self): + if self.win is not None: + self.viz.close(win=self.win) + self.win = None + + def register_event_handler(self, handler): + self.viz.register_event_handler(handler, self.win) + + +class LineVis(BaseVis): + """Visdom Line Visualization Helper Class.""" + + def plot(self, y_data, x_label): + """Plot given data. + + Appends new data to exisiting line visualization. + """ + update = self.update_mode + # update mode must be None the first time or after plot data was removed + if self.removed: + update = None + self.removed = False + + if isinstance(x_label, list): + Y = torch.Tensor(y_data) + X = torch.Tensor(x_label) + else: + y_data = [d.cpu() if torch.is_tensor(d) + else torch.tensor(d) + for d in y_data] + + Y = torch.Tensor(y_data).unsqueeze(dim=0) + X = torch.Tensor([x_label]) + + win = self.viz.line(X=X, Y=Y, opts=self.viz_opts, win=self.win, update=update) + + if self.win is None: + self.win = win + self.viz.save([self.viz.env]) + + def reset(self): + #TODO: currently reset does not empty directly only on the next plot. + # update='remove' is not working as expected. 
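+        # NOTE: reset() therefore only flags the plot as removed; the next plot() call passes update=None and starts the line series fresh.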
+ if self.win is not None: + # self.viz.line(X=None, Y=None, win=self.win, update='remove') + self.removed = True + + +class ImgVis(BaseVis): + """Visdom Image Visualization Helper Class.""" + + def plot(self, images): + """Plot given images.""" + + # images = [img.data if isinstance(img, torch.autograd.Variable) + # else img for img in images] + # images = [img.squeeze(dim=0) if len(img.size()) == 4 + # else img for img in images] + + self.win = self.viz.images( + images, + nrow=1, + opts=self.viz_opts, + win=self.win, ) + self.viz.save([self.viz.env]) + + +def vis_results(visualizer, img, result, target, tracking): + inv_normalize = T.Normalize( + mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.255], + std=[1 / 0.229, 1 / 0.224, 1 / 0.255] + ) + + imgs = [inv_normalize(img).cpu()] + img_ids = [target['image_id'].item()] + for key in ['prev_image', 'random_image']: + if key in target: + imgs.append(inv_normalize(target[key]).cpu()) + img_ids.append(target[f'{key}_id'].item()) + + # img.shape=[3, H, W] + dpi = 96 + figure, axarr = plt.subplots(len(imgs)) + figure.tight_layout() + figure.set_dpi(dpi) + figure.set_size_inches( + imgs[0].shape[2] / dpi, + imgs[0].shape[1] * len(imgs) / dpi) + + if len(imgs) == 1: + axarr = [axarr] + + for ax, img, img_id in zip(axarr, imgs, img_ids): + ax.set_axis_off() + ax.imshow(img.permute(1, 2, 0).clamp(0, 1)) + + ax.text( + 0, 0, f'IMG_ID={img_id}', + fontsize=20, bbox=dict(facecolor='white', alpha=0.5)) + + num_track_queries = num_track_queries_with_id = 0 + if tracking: + num_track_queries = len(target['track_query_boxes']) + num_track_queries_with_id = len(target['track_query_match_ids']) + track_ids = target['track_ids'][target['track_query_match_ids']] + + keep = result['scores'].cpu() > result['scores_no_object'].cpu() + + cmap = plt.cm.get_cmap('hsv', len(keep)) + + prop_i = 0 + for box_id in range(len(keep)): + mask_value = 0 + if tracking: + mask_value = target['track_queries_match_mask'][box_id].item() + + rect_color = 'green' + offset = 0 + text = f"{result['scores'][box_id]:0.2f}" + if mask_value == 1: + offset = 50 + rect_color = 'blue' + text = ( + f"{track_ids[prop_i]}\n" + f"{text}\n" + f"{result['track_queries_with_id_iou'][prop_i]:0.2f}") + prop_i += 1 + elif mask_value == -1: + rect_color = 'red' + + if not keep[box_id]: + continue + + x1, y1, x2, y2 = result['boxes'][box_id] + + axarr[0].add_patch(plt.Rectangle( + (x1, y1), x2 - x1, y2 - y1, + fill=False, color=rect_color, linewidth=2)) + + axarr[0].text( + x1, y1 + offset, text, + fontsize=10, bbox=dict(facecolor='white', alpha=0.5)) + + if 'masks' in result: + mask = result['masks'][box_id][0].numpy() + mask = np.ma.masked_where(mask == 0.0, mask) + + axarr[0].imshow( + mask, alpha=0.5, cmap=colors.ListedColormap([cmap(box_id)])) + + query_keep = keep + if tracking: + query_keep = keep[target['track_queries_match_mask'] == 0] + + legend_handles = [mpatches.Patch( + color='green', + label=f"object queries ({query_keep.sum()}/{len(target['boxes']) - num_track_queries_with_id})\n- cls_score")] + + if num_track_queries: + legend_handles.append(mpatches.Patch( + color='blue', + label=f"track queries ({keep[target['track_queries_match_mask'] == 1].sum()}/{num_track_queries_with_id})\n- track_id\n- cls_score\n- iou")) + if num_track_queries_with_id != num_track_queries: + legend_handles.append(mpatches.Patch( + color='red', + label=f"false track queries ({keep[target['track_queries_match_mask'] == -1].sum()}/{num_track_queries - num_track_queries_with_id})")) + + 
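# attach the color-coded legend (object / track / false track queries) to the current-frame axis +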
axarr[0].legend(handles=legend_handles) + + i = 1 + for frame_prefix in ['prev', 'random']: + if f'{frame_prefix}_image_id' not in target or f'{frame_prefix}_boxes' not in target: + continue + + cmap = plt.cm.get_cmap('hsv', len(target[f'{frame_prefix}_track_ids'])) + + for j, track_id in enumerate(target[f'{frame_prefix}_track_ids']): + x1, y1, x2, y2 = target[f'{frame_prefix}_boxes'][j] + axarr[i].text( + x1, y1, f"track_id={track_id}", + fontsize=10, bbox=dict(facecolor='white', alpha=0.5)) + axarr[i].add_patch(plt.Rectangle( + (x1, y1), x2 - x1, y2 - y1, + fill=False, color='green', linewidth=2)) + + if f'{frame_prefix}_masks' in target: + mask = target[f'{frame_prefix}_masks'][j].cpu().numpy() + mask = np.ma.masked_where(mask == 0.0, mask) + + axarr[i].imshow( + mask, alpha=0.5, cmap=colors.ListedColormap([cmap(j)])) + i += 1 + + plt.subplots_adjust(wspace=0.01, hspace=0.01) + plt.axis('off') + + img = fig_to_numpy(figure).transpose(2, 0, 1) + plt.close() + + visualizer.plot(img) + + +def build_visualizers(args: dict): + visualizers = {} + visualizers['train'] = {} + visualizers['val'] = {} + + if args.eval_only or args.no_vis: + return visualizers + + env_name = str(args.output_dir).split('/')[-1] + + vis_kwargs = { + 'env': env_name, + 'resume': args.resume and args.resume_vis, + 'port': args.vis_port, + 'server': args.vis_server} + + # + # METRICS + # + + legend = [ + 'class_error', + 'loss', + 'loss_bbox', + 'loss_ce', + 'loss_giou', + 'loss_mask', + 'loss_dice', + 'cardinality_error_unscaled', + 'loss_bbox_unscaled', + 'loss_ce_unscaled', + 'loss_giou_unscaled', + 'loss_mask_unscaled', + 'loss_dice_unscaled', + 'lr', + 'lr_backbone', + 'iter_time' + ] + + if not args.masks: + legend.remove('loss_mask') + legend.remove('loss_mask_unscaled') + legend.remove('loss_dice') + legend.remove('loss_dice_unscaled') + + opts = dict( + title="TRAIN METRICS ITERS", + xlabel='ITERS', + ylabel='METRICS', + width=1000, + height=500, + legend=legend) + + # TRAIN + visualizers['train']['iter_metrics'] = LineVis(opts, **vis_kwargs) + + opts = copy.deepcopy(opts) + opts['title'] = "TRAIN METRICS EPOCHS" + opts['xlabel'] = "EPOCHS" + opts['legend'].remove('lr') + opts['legend'].remove('lr_backbone') + opts['legend'].remove('iter_time') + visualizers['train']['epoch_metrics'] = LineVis(opts, **vis_kwargs) + + # VAL + opts = copy.deepcopy(opts) + opts['title'] = "VAL METRICS EPOCHS" + opts['xlabel'] = "EPOCHS" + visualizers['val']['epoch_metrics'] = LineVis(opts, **vis_kwargs) + + # + # EVAL COCO + # + + legend = [ + 'BBOX AP IoU=0.50:0.95', + 'BBOX AP IoU=0.50', + 'BBOX AP IoU=0.75', + ] + + if args.masks: + legend.extend([ + 'MASK AP IoU=0.50:0.95', + 'MASK AP IoU=0.50', + 'MASK AP IoU=0.75']) + + if args.tracking and args.tracking_eval: + legend.extend(['MOTA', 'IDF1']) + + opts = dict( + title='TRAIN EVAL EPOCHS', + xlabel='EPOCHS', + ylabel='METRICS', + width=1000, + height=500, + legend=legend) + + # TRAIN + visualizers['train']['epoch_eval'] = LineVis(opts, **vis_kwargs) + + # VAL + opts = copy.deepcopy(opts) + opts['title'] = 'VAL EVAL EPOCHS' + visualizers['val']['epoch_eval'] = LineVis(opts, **vis_kwargs) + + # + # EXAMPLE RESULTS + # + + opts = dict( + title="TRAIN EXAMPLE RESULTS", + width=2500, + height=2500) + + # TRAIN + visualizers['train']['example_results'] = ImgVis(opts, **vis_kwargs) + + # VAL + opts = copy.deepcopy(opts) + opts['title'] = 'VAL EXAMPLE RESULTS' + visualizers['val']['example_results'] = ImgVis(opts, **vis_kwargs) + + return visualizers diff --git a/src/train.py 
b/src/train.py new file mode 100644 index 0000000..a93f1b7 --- /dev/null +++ b/src/train.py @@ -0,0 +1,361 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import datetime +import os +import random +import time +from argparse import Namespace +from pathlib import Path + +import numpy as np +import sacred +import torch +import yaml +from torch.utils.data import DataLoader, DistributedSampler + +import trackformer.util.misc as utils +from trackformer.datasets import build_dataset +from trackformer.engine import evaluate, train_one_epoch, evaluate_video_sgg +from trackformer.models import build_model +from trackformer.util.misc import nested_dict_to_namespace +from trackformer.util.plot_utils import get_vis_win_names +from trackformer.vis import build_visualizers + +ex = sacred.Experiment('train') +ex.add_config('cfgs/train.yaml') +ex.add_named_config('deformable', 'cfgs/train_deformable.yaml') +ex.add_named_config('tracking', 'cfgs/train_tracking.yaml') +ex.add_named_config('crowdhuman', 'cfgs/train_crowdhuman.yaml') +ex.add_named_config('mot17', 'cfgs/train_mot17.yaml') +ex.add_named_config('mot17_cross_val', 'cfgs/train_mot17_cross_val.yaml') +ex.add_named_config('mots20', 'cfgs/train_mots20.yaml') +ex.add_named_config('coco_person_masks', 'cfgs/train_coco_person_masks.yaml') +ex.add_named_config('actiongenome', 'cfgs/train_actiongenome.yaml') +ex.add_named_config('vidhoi', 'cfgs/train_vidhoi.yaml') +ex.add_named_config('full_res', 'cfgs/train_full_res.yaml') +ex.add_named_config('focal_loss', 'cfgs/train_focal_loss.yaml') +ex.add_named_config('consistent_pairing', 'cfgs/consistent_pairing.yaml') +ex.add_named_config('hoi', 'cfgs/train_hoi.yaml') +ex.add_named_config('vsgg', 'cfgs/train_vsgg.yaml') + + +def train(args: Namespace) -> None: + print(args) + + utils.init_distributed_mode(args) + print("git:\n {}\n".format(utils.get_sha())) + + if args.debug: + # args.tracking_eval = False + args.num_workers = 0 + + if not args.deformable: + assert args.num_feature_levels == 1 + if args.tracking: + assert args.batch_size == 1 + + if args.tracking_eval: + assert 'mot' in args.dataset + + output_dir = Path(args.output_dir) + if args.output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + + yaml.dump( + vars(args), + open(output_dir / 'config.yaml', 'w'), allow_unicode=True) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + + os.environ['PYTHONHASHSEED'] = str(seed) + # os.environ['NCCL_DEBUG'] = 'INFO' + # os.environ["NCCL_TREE_THRESHOLD"] = "0" + + np.random.seed(seed) + random.seed(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + + model, criterion, postprocessors = build_model(args) + model.to(device) + + visualizers = build_visualizers(args) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('NUM TOTAL MODEL PARAMS:', sum(p.numel() for p in model.parameters())) + print('NUM TRAINABLE MODEL PARAMS:', n_parameters) + + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + scratch_params = [] + if args.resume: + print(f"Resume from model: {args.resume}") + if args.resume.startswith('https'): + checkpoint = 
torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + + model_state_dict = model_without_ddp.state_dict() + checkpoint_state_dict = checkpoint['model'] + checkpoint_state_dict = { + k.replace('detr.', ''): v for k, v in checkpoint['model'].items()} + + resume_state_dict = {} + for k, v in model_state_dict.items(): + if k not in checkpoint_state_dict: + resume_value = v + scratch_params.append(k) + elif v.shape != checkpoint_state_dict[k].shape: + checkpoint_value = checkpoint_state_dict[k] + num_dims = len(checkpoint_value.shape) + + if 'norm' in k: + resume_value = checkpoint_value.repeat(2) + elif 'multihead_attn' in k or 'self_attn' in k: + resume_value = checkpoint_value.repeat(num_dims * (2, )) + elif 'linear1' in k or 'query_embed' in k: + if checkpoint_value.shape[1] * 2 == v.shape[1]: + # from hidden size 256 to 512 + resume_value = checkpoint_value.repeat(1, 2) + elif checkpoint_value.shape[0] * 5 == v.shape[0]: + # from 100 to 500 object queries + resume_value = checkpoint_value.repeat(5, 1) + elif checkpoint_value.shape[0] > v.shape[0]: + resume_value = checkpoint_value[:v.shape[0]] + elif checkpoint_value.shape[0] < v.shape[0]: + resume_value = v + else: + raise NotImplementedError + elif 'linear2' in k or 'input_proj' in k: + resume_value = checkpoint_value.repeat((2,) + (num_dims - 1) * (1, )) + elif 'class_embed' or 'semantic_embed' in k: + # person and no-object class + # resume_value = checkpoint_value[[1, -1]] + # resume_value = checkpoint_value[[0, -1]] + # resume_value = checkpoint_value[[1,]] + resume_value = v + else: + raise NotImplementedError(f"No rule for {k} with shape {v.shape}.") + + print(f"Load {k} {tuple(v.shape)} from resume model " + f"{tuple(checkpoint_value.shape)}.") + elif args.resume_shift_neuron and 'class_embed' in k: + checkpoint_value = checkpoint_state_dict[k] + # no-object class + resume_value = checkpoint_value.clone() + # no-object class + # resume_value[:-2] = checkpoint_value[1:-1].clone() + resume_value[:-1] = checkpoint_value[1:].clone() + resume_value[-2] = checkpoint_value[0].clone() + print(f"Load {k} {tuple(v.shape)} from resume model and " + "shift class embed neurons to start with label=0 at neuron=0.") + else: + resume_value = checkpoint_state_dict[k] + + resume_state_dict[k] = resume_value + + if args.masks and args.load_mask_head_from_model is not None: + checkpoint_mask_head = torch.load( + args.load_mask_head_from_model, map_location='cpu') + + for k, v in resume_state_dict.items(): + if (('bbox_attention' in k or 'mask_head' in k) + and v.shape == checkpoint_mask_head['model'][k].shape): + print(f'Load {k} {tuple(v.shape)} from mask head model.') + resume_state_dict[k] = checkpoint_mask_head['model'][k] + + model_without_ddp.load_state_dict(resume_state_dict) + + param_dicts = [ + {"params": [p for n, p in model_without_ddp.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names + args.lr_linear_proj_names + ['layers_track_attention'] + scratch_params) and p.requires_grad], + "lr": args.lr,}, + {"params": [p for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad], + "lr": args.lr_backbone}, + {"params": [p for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad], + "lr": args.lr * args.lr_linear_proj_mult}, + {"params": [p for n, p in model_without_ddp.named_parameters() + if 
match_name_keywords(n, scratch_params) and p.requires_grad], + "lr": args.lr * 10}, + ] + print(f'scratch_params={scratch_params} with lr={args.lr * 10}') + if args.track_attention: + param_dicts.append({ + "params": [p for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, ['layers_track_attention']) and p.requires_grad], + "lr": args.lr_track}) + + optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) + + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [args.lr_drop]) + + dataset_train = build_dataset(split='train', args=args) + dataset_val = build_dataset(split='val', args=args) + + if args.distributed: + sampler_train = utils.DistributedWeightedSampler(dataset_train) + # sampler_train = DistributedSampler(dataset_train) + sampler_val = DistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, args.batch_size, drop_last=True) + + data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, + num_workers=args.num_workers) + data_loader_val = DataLoader( + dataset_val, args.batch_size, + sampler=sampler_val, + drop_last=False, + collate_fn=utils.collate_fn, + num_workers=args.num_workers) + + # # check data + # imgs, tgts = data_loader_train.dataset.__getitem__(417) + # for j in range(len(tgts)): check_annotation(nested_tensor_from_tensor_list(imgs), tgts, idx=j) + + best_val_stats = None + if args.resume: + # RESUME OPTIM + if not args.eval_only and args.resume_optim: + if 'optimizer' in checkpoint: + for c_p, p in zip(checkpoint['optimizer']['param_groups'], param_dicts): + c_p['lr'] = p['lr'] + + optimizer.load_state_dict(checkpoint['optimizer']) + if 'lr_scheduler' in checkpoint: + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + if 'epoch' in checkpoint: + args.start_epoch = checkpoint['epoch'] + 1 + + best_val_stats = checkpoint['best_val_stats'] + + # RESUME VIS + if not args.eval_only and args.resume_vis and 'vis_win_names' in checkpoint: + for k, v in visualizers.items(): + for k_inner in v.keys(): + visualizers[k][k_inner].win = checkpoint['vis_win_names'][k][k_inner] + + if args.eval_only: + if args.video_sgg_eval: + evaluate_video_sgg(model, data_loader_val.dataset, device, args, postprocessors) + else: + _, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, device, + output_dir, visualizers['val'], args) + if args.output_dir: + utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs + 1): + # TRAIN + if args.distributed: + sampler_train.set_epoch(epoch) + train_one_epoch( + model, criterion, postprocessors, data_loader_train, optimizer, device, epoch, + visualizers['train'], args) + + if args.eval_train: + random_transforms = data_loader_train.dataset._transforms + data_loader_train.dataset._transforms = data_loader_val.dataset._transforms + evaluate( + model, criterion, postprocessors, data_loader_train, device, + output_dir, visualizers['train'], args, epoch) + data_loader_train.dataset._transforms = random_transforms + + lr_scheduler.step() + + checkpoint_paths = [output_dir / 'checkpoint.pth'] + + # VAL + if epoch == 1 or not epoch % args.val_interval: + if 
args.video_sgg_eval: + print('Evaluate by videos:') + evaluate_video_sgg(model, data_loader_val.dataset, device, args, postprocessors) + else: + val_stats, _ = evaluate( + model, criterion, postprocessors, data_loader_val, device, + output_dir, visualizers['val'], args, epoch) + + checkpoint_paths = [output_dir / 'checkpoint.pth'] + # extra checkpoint before LR drop and every 100 epochs + # if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0: + # checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') + + # checkpoint for best validation stats + # stat_names = ['BBOX_AP_IoU_0_50-0_95', 'BBOX_AP_IoU_0_50', 'BBOX_AP_IoU_0_75'] + stat_names = ['BBOX_AP_IoU_0_50-0_95'] + if args.masks: + stat_names.extend(['MASK_AP_IoU_0_50-0_95', 'MASK_AP_IoU_0_50', 'MASK_AP_IoU_0_75']) + if args.tracking and args.tracking_eval: + # stat_names.extend(['MOTA', 'IDF1']) + stat_names.extend(['MOTA']) + + if best_val_stats is None: + best_val_stats = val_stats + best_val_stats = [best_stat if best_stat > stat else stat + for best_stat, stat in zip(best_val_stats, + val_stats)] + for b_s, s, n in zip(best_val_stats, val_stats, stat_names): + if b_s == s: + checkpoint_paths.append(output_dir / f"checkpoint_best_{n}.pth") + + # MODEL SAVING + if args.output_dir: + if args.save_model_interval and not epoch % args.save_model_interval: + checkpoint_paths.append(output_dir / f"checkpoint_epoch{epoch}.pth") + + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args, + 'vis_win_names': get_vis_win_names(visualizers), + 'best_val_stats': best_val_stats + }, checkpoint_path) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +@ex.main +def load_config(_config, _run): + """ We use sacred only for config loading from YAML files. 
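The merged config is converted to a nested Namespace in __main__ and passed to train().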
""" + sacred.commands.print_config(_run) + + +if __name__ == '__main__': + # TODO: hierachical Namespacing for nested dict + config = ex.run_commandline().config + args = nested_dict_to_namespace(config) + # args.train = Namespace(**config['train']) + train(args) diff --git a/src/train_frcnn.py b/src/train_frcnn.py new file mode 100644 index 0000000..9d0825e --- /dev/null +++ b/src/train_frcnn.py @@ -0,0 +1,299 @@ +import sacred +import torch +import yaml +import os +import numpy as np +import random +import time +import datetime +from pathlib import Path +from torch.utils.data import DataLoader, DistributedSampler +import math +import sys +import torchvision.models.detection.mask_rcnn +import torchvision +from tqdm import tqdm +import motmetrics as mm +from STTran.sttran import build_frcnn +from base_trackers import CentroidTracker, SORT, IOUTracker, BYTETracker + +import trackformer.util.misc as utils +from trackformer.util.misc import nested_dict_to_namespace +from trackformer.datasets import build_dataset +from trackformer.datasets import get_coco_api_from_dataset +from trackformer.datasets.coco_eval import CocoEvaluator +from trackformer.util.box_ops import box_xyxy_to_xywh +from trackformer.util.track_utils import evaluate_mot_accums + +def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None): + model.train() + metric_logger = utils.MetricLogger(print_freq, delimiter=" ") + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = f"Epoch: [{epoch}]" + + # todo: same data format + for images, targets in metric_logger.log_every(data_loader, print_freq, header): + images = list(image.to(device) for image in images) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + for t in targets: t['labels'] += 1 # compatibility with torchvison frcnn (set __background__=0) + + with torch.cuda.amp.autocast(enabled=scaler is not None): + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + loss_value = losses_reduced.item() + + if not math.isfinite(loss_value): + print(f"Loss is {loss_value}, stopping training") + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + if scaler is not None: + scaler.scale(losses).backward() + scaler.step(optimizer) + scaler.update() + else: + losses.backward() + optimizer.step() + + metric_logger.update(loss=losses_reduced, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + return metric_logger + + +def _get_iou_types(model): + model_without_ddp = model + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_without_ddp = model.module + iou_types = ["bbox"] + if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): + iou_types.append("segm") + if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): + iou_types.append("keypoints") + return iou_types + + +@torch.no_grad() +def evaluate(model, data_loader, device, print_freq=100): + n_threads = torch.get_num_threads() + # FIXME remove this and make paste_masks_in_image run on the GPU + torch.set_num_threads(1) + cpu_device = torch.device("cpu") + model.eval() + metric_logger = utils.MetricLogger(print_freq, delimiter=" ") + header = "Test:" + + coco = get_coco_api_from_dataset(data_loader.dataset) + iou_types = 
_get_iou_types(model) + coco_evaluator = CocoEvaluator(coco, iou_types) + + for images, targets in metric_logger.log_every(data_loader, 100, header): + images = list(img.to(device) for img in images) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + model_time = time.time() + outputs = model(images) + + outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] + for o in outputs: o['labels'] -= 1 # for compatibility with evaluation code + model_time = time.time() - model_time + + res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + evaluator_time = time.time() + coco_evaluator.update(res) + evaluator_time = time.time() - evaluator_time + metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + coco_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + coco_evaluator.accumulate() + coco_evaluator.summarize() + torch.set_num_threads(n_threads) + return coco_evaluator + +@torch.no_grad() +def evaludate_mot(model, dataset_val, device, chosen_tracker, det_threshold=0.8): + print(f"Tracker: {chosen_tracker}") + model.eval() + mot_accums = [] + video_names, video_startend_idxs = dataset_val.sequence_infos() + if utils.get_world_size() > 1: # for multi-GPUs + video_names = video_names[utils.get_rank()::utils.get_world_size()] + video_startend_idxs = video_startend_idxs[utils.get_rank()::utils.get_world_size()] + + for vid, (video_name, v_seg_info) in enumerate(zip(video_names, video_startend_idxs)): + print(f'TRACK SEQ: {video_name} ({vid}/{len(video_names)})') + video_loader = DataLoader( + torch.utils.data.Subset(dataset_val, range(v_seg_info[0], v_seg_info[1]+1)), + collate_fn=utils.frcnn_collate_fn, + num_workers=args.num_workers + ) + + # trackers + if chosen_tracker == 'IOUTracker': + tracker, det_threshold = IOUTracker(iou_threshold=0.1), 0.8 + elif chosen_tracker == 'SORT': + tracker, det_threshold = SORT(iou_threshold=0.1), 0.8 + elif chosen_tracker == 'BYTETracker': + tracker, det_threshold = BYTETracker(iou_threshold=0.1), 0 # filter by confidence score inside + + # track one video + mot_accums.append(mm.MOTAccumulator(auto_id=True)) + for i, (frames, targets) in enumerate(tqdm(video_loader, file=sys.stdout)): + assert len(targets) == 1 + frames = list(img.to(device) for img in frames) + frame_det_out = model(frames)[0] + + # get trackers + kept_box_mask = frame_det_out['scores'] > det_threshold + tracks = tracker.update(box_xyxy_to_xywh(frame_det_out['boxes'][kept_box_mask]).cpu().numpy(), + frame_det_out['scores'][kept_box_mask].cpu().numpy(), + frame_det_out['labels'][kept_box_mask].cpu().numpy()-1) + pred_boxes, pred_track_ids, pred_labels = [], [], [] + for track in tracks: + frame_index, track_id, bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, truncation, occlusion = track + pred_boxes.append([bbox_left, bbox_top, bbox_width, bbox_height]) + pred_track_ids.append(track_id) + pred_labels.append(object_category) + + # mot eval + gt_boxes, gt_track_ids, gt_labels = box_xyxy_to_xywh(targets[0]['boxes']).cpu().numpy(), targets[0]['track_ids'].tolist(), targets[0]['labels'].cpu().numpy() + distance = mm.distances.iou_matrix(gt_boxes, np.array(pred_boxes), max_iou=0.5) + if len(gt_labels) == 0 or len(pred_labels) == 0: + label_match = np.empty((0, 0)) + else: + label_match = (gt_labels[:, None] == 
np.array(pred_labels)[None, :]) + distance = np.where(label_match, distance, np.nan) + mot_accums[-1].update(gt_track_ids, pred_track_ids, distance) + + mot_accums_all, video_names_all = utils.all_gather(mot_accums), utils.all_gather(video_names) + eval_summary, eval_summary_str = evaluate_mot_accums(sum(mot_accums_all, []), sum(video_names_all, [])) + print(eval_summary_str) + print(f'#videos={len(sum(video_names_all, []))}') + +########## configs ########## +ex = sacred.Experiment('train') +ex.add_config('cfgs/train.yaml') +ex.add_named_config('vidhoi', 'cfgs/train_vidhoi.yaml') +ex.add_named_config('frcnn', 'cfgs/train_frcnn.yaml') + +@ex.main +def load_config(_config, _run): + """ We use sacred only for config loading from YAML files. """ + sacred.commands.print_config(_run) + +config = ex.run_commandline().config +args = nested_dict_to_namespace(config) +args.num_classes = 1+78 # backgrdound + objs + +########## environmental settings ########## +utils.init_distributed_mode(args) +print("git:\n {}\n".format(utils.get_sha())) + +output_dir = Path(args.output_dir) +if args.output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + yaml.dump(vars(args), open(output_dir / 'config.yaml', 'w'), allow_unicode=True) +device = torch.device(args.device) + +seed = args.seed + utils.get_rank() +os.environ['PYTHONHASHSEED'] = str(seed) +np.random.seed(seed) +random.seed(seed) + +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# torch.backends.cudnn.deterministic = True + +########## model ########## +model = build_frcnn(args) +model.to(device) + +model_without_ddp = model +if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + +n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) +print('NUM TOTAL MODEL PARAMS:', sum(p.numel() for p in model.parameters())) +print('NUM TRAINABLE MODEL PARAMS:', n_parameters) + +params = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) + +lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [args.lr_drop]) + +if args.resume: + print(f"Resume from model: {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + resume_state_dict = checkpoint['model'] + model_without_ddp.load_state_dict(resume_state_dict) + +########## dataset ########## +dataset_train = build_dataset(split='train', args=args) +dataset_val = build_dataset(split='val', args=args) + +if args.distributed: + sampler_train = utils.DistributedWeightedSampler(dataset_train) + # sampler_train = DistributedSampler(dataset_train) + sampler_val = DistributedSampler(dataset_val, shuffle=False) +else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + +batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, args.batch_size, drop_last=True) + +data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.frcnn_collate_fn, + num_workers=args.num_workers) +data_loader_val = DataLoader( + dataset_val, args.batch_size, + sampler=sampler_val, + drop_last=False, + collate_fn=utils.frcnn_collate_fn, + num_workers=args.num_workers) + +########## train & eval ########## +if args.eval_only: + evaluate(model, data_loader_val, device=device) + # evaludate_mot(model, dataset_val, device=device, 
chosen_tracker=args.sgg_postprocessing_tracker) +else: + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs + 1): + if args.distributed: + sampler_train.set_epoch(epoch) + + train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=500) + lr_scheduler.step() + if args.output_dir: + checkpoint = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "args": args, + "epoch": epoch, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, f'checkpoint_epoch{epoch}.pth')) + + # evaluate after every epoch + evaluate(model, data_loader_val, device=device) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) diff --git a/src/train_sttran.py b/src/train_sttran.py new file mode 100644 index 0000000..3fe0f1a --- /dev/null +++ b/src/train_sttran.py @@ -0,0 +1,264 @@ +import sacred +import torch +import yaml +import os +import numpy as np +import random +import time +import datetime +from pathlib import Path +from torch.utils.data import DataLoader, DistributedSampler +import math +import sys +import torchvision.models.detection.mask_rcnn +import torchvision +import copy + +import trackformer.util.misc as utils +from trackformer.util.misc import nested_dict_to_namespace +from trackformer.datasets import build_dataset +from trackformer.datasets import get_coco_api_from_dataset +from trackformer.datasets.coco_eval import CocoEvaluator +from trackformer.datasets.vidhoi_eval import VidHOIEvaluator +from STTran.sttran import build_sttran +from trackformer.util.plot_utils import check_sttran_prediction + +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): + model.train() + metric_logger = utils.MetricLogger(print_freq, delimiter=" ") + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = f"Epoch: [{epoch}]" + + # todo: same data format + for images, targets in metric_logger.log_every(data_loader, print_freq, header): + images = images.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + for t in targets: t['labels'] += 1 # compatibility with torchvison frcnn (set __background__=0) + + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + loss_value = losses_reduced.item() + + if not math.isfinite(loss_value): + print(f"Loss is {loss_value}, stopping training") + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + if args.clip_max_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_max_norm) + optimizer.step() + + metric_logger.update(loss=losses_reduced, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + return metric_logger + + +def _get_iou_types(model): + model_without_ddp = model + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_without_ddp = model.module + iou_types = ["bbox"] + if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): + iou_types.append("segm") + if isinstance(model_without_ddp, 
torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, print_freq=100):
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(print_freq, delimiter=" ")
+    header = "Test:"
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    vidhoi_evaluator = VidHOIEvaluator(args)
+
+    model_inference_time = 0
+    for images, targets in metric_logger.log_every(data_loader, 10, header):
+        # if targets[0]['image_id'].item() < 280: continue
+        frame_tic = time.time()
+        images = images.to(device)
+        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        for t in targets: t['labels'] += 1 # compatibility with torchvision frcnn (set __background__=0)
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        model_time = time.time()
+        try:
+            outputs = model(images, targets) # requires pytorch<=1.5
+        except Exception:  # skip frames where the forward pass fails (see pytorch<=1.5 note above)
+            # outputs = model(images, targets)
+            continue
+        model_time = time.time() - model_time
+
+        # evaluation
+        evaluator_time = time.time()
+        det_nums, det_res = outputs['box_nums'], {}
+        for t, b, l, s in zip(targets, outputs['boxes'][:, 1:].to(cpu_device).split(det_nums, dim=0),
+                              (outputs['labels']-1).to(cpu_device).split(det_nums, dim=0), # -1 for compatibility with evaluation code
+                              outputs['scores'].to(cpu_device).split(det_nums, dim=0)):
+            det_res[t['image_id'].item()] = {'boxes': b, 'labels': l, 'scores': s}
+        coco_evaluator.update(copy.deepcopy(det_res)) # detection evaluation
+
+        rel_nums, rel_outs = outputs['rel_pair_nums'], {'pred_rel_pairs': [], 'pred_relations': []}
+        for idx, (rp, rl) in enumerate(zip(outputs['rel_pair_idxs'].split(rel_nums, dim=0), outputs['rel_logits'].split(rel_nums, dim=0))):
+            rel_outs['pred_rel_pairs'].append(rp - sum(det_nums[:idx]))
+            rel_outs['pred_relations'].append(rl)
+        top_pred_rel_pairs = vidhoi_evaluator.sttran_update(targets, rel_outs, box_preds=det_res) # relation evaluation
+
+        # for idx in range(len(targets)):
+        #     check_sttran_prediction(images, rel_outs, det_res, targets, top_pred_rel_pairs, idx=idx, save_fig_dir=f"{args.output_dir}/demo")
+
+        evaluator_time = time.time() - evaluator_time
+        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+        model_inference_time += time.time() - frame_tic
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+
+    # accumulate predictions from all images
+    coco_evaluator.synchronize_between_processes()
+    coco_evaluator.accumulate()
+    coco_evaluator.summarize()
+
+    vidhoi_evaluator.synchronize_between_processes()
+    vidhoi_evaluator.evaluate()
+    print(f"model_inference_time={model_inference_time}s")
+    print(f"Model: time_per_frame={model_inference_time/len(vidhoi_evaluator.gts)*1000 :.2f}ms, frame_per_second={len(vidhoi_evaluator.gts)/model_inference_time :.2f}")
+    return coco_evaluator
+
+########## configs ##########
+ex = sacred.Experiment('train')
+ex.add_config('cfgs/train.yaml')
+ex.add_named_config('vidhoi', 'cfgs/train_vidhoi.yaml')
+ex.add_named_config('sttran', 'cfgs/train_sttran.yaml')
+
+@ex.main
+def load_config(_config, _run):
+    """ We use sacred only for config loading from YAML files. 
""" + sacred.commands.print_config(_run) + +config = ex.run_commandline().config +args = nested_dict_to_namespace(config) +args.num_classes = 1+78 +args.object_detector = 'frcnn' + +########## environmental settings ########## +utils.init_distributed_mode(args) +print("git:\n {}\n".format(utils.get_sha())) + +output_dir = Path(args.output_dir) +if args.output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + yaml.dump(vars(args), open(output_dir / 'config.yaml', 'w'), allow_unicode=True) +device = torch.device(args.device) + +seed = args.seed + utils.get_rank() +os.environ['PYTHONHASHSEED'] = str(seed) +np.random.seed(seed) +random.seed(seed) + +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# torch.backends.cudnn.deterministic = True + +########## dataset ########## +dataset_train = build_dataset(split='train', args=args) +dataset_val = build_dataset(split='val', args=args) + +if args.distributed: + sampler_train = utils.DistributedWeightedSampler(dataset_train) + # sampler_train = DistributedSampler(dataset_train) + sampler_val = DistributedSampler(dataset_val, shuffle=False) +else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + +batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, args.batch_size, drop_last=True) + +data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.sttran_collate_fn, + num_workers=args.num_workers) +data_loader_val = DataLoader( + dataset_val, args.batch_size, + sampler=sampler_val, + drop_last=False, + collate_fn=utils.sttran_collate_fn, + num_workers=args.num_workers) + +########## model ########## +model = build_sttran(args, obj_classes=[x['name'] for x in dataset_train.coco.dataset['categories']]) +model.to(device) + +model_without_ddp = model +if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + +n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) +print('NUM TOTAL MODEL PARAMS:', sum(p.numel() for p in model.parameters())) +print('NUM TRAINABLE MODEL PARAMS:', n_parameters) + +params = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay) + +lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [args.lr_drop]) + +if args.resume: + print(f"Resume from model: {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + resume_state_dict = checkpoint['model'] + if 'frcnn' in args.resume: + model_without_ddp.frcnn.load_state_dict(resume_state_dict, strict=False) # only load detector part + else: + model_without_ddp.load_state_dict(resume_state_dict, strict=False) + +########## train & eval ########## +if args.eval_only: + evaluate(model, data_loader_val, device=device) +else: + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs + 1): + if args.distributed: + sampler_train.set_epoch(epoch) + + train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=10) + lr_scheduler.step() + if args.output_dir: + checkpoint = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "args": args, + "epoch": epoch, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, f'checkpoint_epoch{epoch}.pth')) + + # evaluate 
after every epoch + evaluate(model, data_loader_val, device=device) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str))
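+
+
+# Usage sketch: the sacred experiment above loads 'cfgs/train.yaml' as the base
+# config and registers 'vidhoi' and 'sttran' as named configs, so a typical
+# training run would be launched with sacred's command line interface as
+#
+#     python src/train_sttran.py with vidhoi sttran
+#
+# Assuming 'eval_only' and 'resume' are top-level keys of cfgs/train.yaml (both
+# are read from args above), sacred's key=value overrides allow an
+# evaluation-only run of a saved checkpoint, e.g.
+#
+#     python src/train_sttran.py with vidhoi sttran eval_only=True resume=<path/to/checkpoint.pth>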