-
Notifications
You must be signed in to change notification settings - Fork 77
/
export.py
62 lines (49 loc) · 2.22 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import tempfile
import unittest

import cv2

from modelscope.exporters.cv import CartoonTranslationExporter
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.trainers.cv import CartoonTranslationTrainer
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestImagePortraitStylizationTrainer(unittest.TestCase):
    """End-to-end check: fine-tune, export a frozen graph, then run inference.

    The test trains ``CartoonTranslationTrainer`` for a few steps on a small
    MsDataset. It exports the resulting checkpoint to ``cartoon_h.pb`` inside
    the trainer's model directory. Finally it runs the
    ``image_portrait_stylization`` pipeline from that directory.
    """

    def setUp(self) -> None:
        self.task = Tasks.image_portrait_stylization
        self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        """Fine-tune briefly, export the checkpoint as a .pb, infer with it."""
        model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
        data_dir = MsDataset.load(
            'dctnet_train_clipart_mini_ms',
            namespace='menyifang',
            split='train').config_kwargs['split_config']['train']
        data_photo = os.path.join(data_dir, 'face_photo')
        data_cartoon = os.path.join(data_dir, 'face_cartoon')
        max_steps = 10

        # Train into a fresh temporary directory rather than a fixed relative
        # 'exp_localtoon' dir in the CWD. The fixed dir persisted across runs,
        # so a rerun could see a stale 'model-0' checkpoint and it left
        # artifacts behind. Everything that may read from work_dir stays
        # inside the `with` so the directory is only removed afterwards.
        with tempfile.TemporaryDirectory() as tmp_root:
            work_dir = os.path.join(tmp_root, 'exp_localtoon')
            trainer = CartoonTranslationTrainer(
                model=model_id,
                work_dir=work_dir,
                photo=data_photo,
                cartoon=data_cartoon,
                max_steps=max_steps)
            trainer.train()

            # export pb file
            # NOTE(review): 'model-0' is assumed to be the checkpoint name the
            # trainer writes under <work_dir>/saved_models — confirm there.
            ckpt_path = os.path.join(work_dir, 'saved_models', 'model-0')
            # The frozen graph goes into trainer.model_dir, which is the
            # directory the pipeline below is constructed from.
            pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb')
            exporter = CartoonTranslationExporter()
            exporter.export_frozen_graph_def(
                ckpt_path=ckpt_path, frozen_graph_path=pb_path)

            # infer with pb file
            self.pipeline_person_image_cartoon(trainer.model_dir)

    def pipeline_person_image_cartoon(self, model_dir):
        """Run the stylization pipeline loaded from ``model_dir``.

        The pipeline runs on ``self.test_image``. The test fails if the
        pipeline returns ``None``. Otherwise the stylized image is written to
        ``result.png`` in the current working directory for manual inspection.
        """
        pipeline_cartoon = pipeline(task=self.task, model=model_dir)
        result = pipeline_cartoon(input=self.test_image)
        # Without this assertion a failed inference (None result) would let
        # the test pass silently.
        self.assertIsNotNone(result, 'pipeline returned no result')
        cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
        print(f'Output written to {os.path.abspath("result.png")}')
if __name__ == '__main__':
    # Running the file directly discovers and runs the TestCase classes in
    # this module; 'module' is spelled out but equals unittest.main's default.
    unittest.main(module='__main__')