support onnx now
teowu committed Jul 31, 2023
1 parent 6af95fb commit cb00806
Showing 2 changed files with 180 additions and 0 deletions.
41 changes: 41 additions & 0 deletions convert_to_onnx.py
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn

import dover
from dover.models import VQAHead
from dover.models import VQABackbone as VideoBackbone, convnext_3d_tiny

class MinimumDOVER(nn.Module):
    """Minimal two-branch DOVER: technical and aesthetic backbones, each with a VQA head."""

    def __init__(self):
        super().__init__()
        self.technical_backbone = VideoBackbone(use_checkpoint=False)
        self.aesthetic_backbone = convnext_3d_tiny(pretrained=False)
        self.technical_head = VQAHead(pre_pool=False, in_channels=768)
        self.aesthetic_head = VQAHead(pre_pool=False, in_channels=768)

    def forward(self, aesthetic_view, technical_view):
        # eval() + no_grad() keep the traced graph inference-only.
        self.eval()
        with torch.no_grad():
            aesthetic_score = self.aesthetic_head(self.aesthetic_backbone(aesthetic_view))
            technical_score = self.technical_head(self.technical_backbone(technical_view))

        aesthetic_score_pooled = torch.mean(aesthetic_score, (1, 2, 3, 4))
        technical_score_pooled = torch.mean(technical_score, (1, 2, 3, 4))
        return [aesthetic_score_pooled, technical_score_pooled]

minimum_dover = MinimumDOVER()
sd = torch.load("pretrained_weights/DOVER.pth", map_location="cpu")
minimum_dover.load_state_dict(sd)

if torch.cuda.is_available():
    minimum_dover = minimum_dover.cuda()
    dummy_inputs = (torch.randn(1, 3, 32, 224, 224).cuda(), torch.randn(4, 3, 32, 224, 224).cuda())
else:
    dummy_inputs = (torch.randn(1, 3, 32, 224, 224), torch.randn(4, 3, 32, 224, 224))

torch.onnx.export(minimum_dover, dummy_inputs, "onnx_dover.onnx", verbose=True,
                  input_names=["aes_view", "tech_view"])

print("Successfully exported onnx_dover.onnx")
139 changes: 139 additions & 0 deletions onnx_inference.py
@@ -0,0 +1,139 @@
import torch

import argparse
import pickle as pkl

import decord
import numpy as np
import yaml

import onnxruntime as ort

from dover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition

mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)
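# These are the standard ImageNet channel mean/std on the 0-255 pixel scale
# ([0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] scaled by 255).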


# 4-parameter sigmoid rescaling, as advised by ITU
def fuse_results(results: list):
    x = (results[0] - 0.1107) / 0.07355 * 0.6104 + (
        results[1] + 0.08285
    ) / 0.03774 * 0.3896
    print(x)
    return 1 / (1 + np.exp(-x))
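
# Worked check of the constants above: head outputs equal to the calibration
# means, results = [0.1107, -0.08285], give x = 0 and a fused score of
# 1 / (1 + e^0) = 0.5, the midpoint of the [0, 1] scale.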


def gaussian_rescale(pr):
    # The results should follow N(0,1)
    pr = (pr - np.mean(pr)) / np.std(pr)
    return pr


def uniform_rescale(pr):
    # The result scores should follow U(0,1)
    return np.arange(len(pr))[np.argsort(pr).argsort()] / len(pr)
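    # e.g. uniform_rescale(np.array([0.3, 0.9, 0.1])) -> [1/3, 2/3, 0.]:
    # each score becomes its rank divided by len(pr).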


def rescale_results(results: list, vname="undefined"):
    dbs = {
        "livevqc": "LIVE_VQC",
        "kv1k": "KoNViD-1k",
        "ltest": "LSVQ_Test",
        "l1080p": "LSVQ_1080P",
        "ytugc": "YouTube_UGC",
    }
    for abbr, full_name in dbs.items():
        with open(f"dover_predictions/val-{abbr}.pkl", "rb") as f:
            pr_labels = pkl.load(f)
        aqe_score_set = pr_labels["resize"]
        tqe_score_set = pr_labels["fragments"]
        tqe_score_set_p = np.concatenate((np.array([results[0]]), tqe_score_set), 0)
        aqe_score_set_p = np.concatenate((np.array([results[1]]), aqe_score_set), 0)
        tqe_nscore = gaussian_rescale(tqe_score_set_p)[0]
        tqe_uscore = uniform_rescale(tqe_score_set_p)[0]
        print(f"Compared with all videos in the {full_name} dataset:")
        print(
            f"-- the technical quality of video [{vname}] is better than {int(tqe_uscore*100)}% of videos, with normalized score {tqe_nscore:.2f}."
        )
        aqe_nscore = gaussian_rescale(aqe_score_set_p)[0]
        aqe_uscore = uniform_rescale(aqe_score_set_p)[0]
        print(
            f"-- the aesthetic quality of video [{vname}] is better than {int(aqe_uscore*100)}% of videos, with normalized score {aqe_nscore:.2f}."
        )
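
# Note: rescale_results is not invoked in the demo below; it is kept for
# situating a new video's scores among the precomputed per-dataset
# predictions in dover_predictions/val-*.pkl.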


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-o", "--opt", type=str, default="./dover.yml", help="the option file"
    )

    ## can be your own video
    parser.add_argument(
        "-v",
        "--video_path",
        type=str,
        default="./demo/1724.mp4",
        help="the input video path",
    )

    args = parser.parse_args()

    with open(args.opt, "r") as f:
        opt = yaml.safe_load(f)

    dopt = opt["data"]["val-l1080p"]["args"]

    temporal_samplers = {}
    for stype, sopt in dopt["sample_types"].items():
        if "t_frag" not in sopt:
            # resized temporal sampling for TQE in DOVER
            temporal_samplers[stype] = UnifiedFrameSampler(
                sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
            )
        else:
            # temporal sampling for AQE in DOVER
            temporal_samplers[stype] = UnifiedFrameSampler(
                sopt["clip_len"] // sopt["t_frag"],
                sopt["t_frag"],
                sopt["frame_interval"],
                sopt["num_clips"],
            )
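            # With t_frag, sampling presumably splits each clip into t_frag
            # fragments of clip_len // t_frag frames (clip_len frames total).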

    ### View Decomposition
    views, _ = spatial_temporal_view_decomposition(
        args.video_path, dopt["sample_types"], temporal_samplers
    )

    for k, v in views.items():
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        views[k] = (
            ((v.permute(1, 2, 3, 0) - mean) / std)
            .permute(3, 0, 1, 2)
            .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
            .transpose(0, 1)
        )
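        # Each view now has shape (num_clips, C, T, H, W); since the export
        # declared no dynamic_axes, these must match the fixed shapes of the
        # dummy inputs used at export time.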


    aes_input = views["aesthetic"]
    tech_input = views["technical"]
    ort_session = ort.InferenceSession("onnx_dover.onnx")

    import time

    s = time.time()
    predictions = ort_session.run(
        None, {"aes_view": aes_input.numpy(), "tech_view": tech_input.numpy()}
    )
    print(f"Inference time cost: {time.time() - s:.3f}s.")

    scores = [np.mean(pred) for pred in predictions]
    # predict fused overall score, with default score-level fusion parameters
    print(f"Normalized fused overall score (scale in [0,1]): {fuse_results(scores):.3f}")
