-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencode.py
118 lines (79 loc) · 3.24 KB
/
encode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Script: embed a fixed secret into images with a StegaStamp encoder (see main()).
import os, sys
import os.path as osp
# Restrict CUDA to device 0; this must run before torch is imported below.
os.environ['CUDA_VISIBLE_DEVICES']='0'
# NOTE(review): the project root is the current working directory, so the
# script must be launched from the repository root for `import model` /
# `import utils` (local modules) and the root_dir-relative paths to resolve.
root_dir = os.path.abspath(os.getcwd())
sys.path.append(root_dir)
import model
import utils
import glob, yaml
import bchlib
import numpy as np
# NOTE(review): Image, ImageOps, nn, shutil and csv are not referenced in the
# code visible here; confirm before removing.
from PIL import Image, ImageOps
import torch
import cv2
from torch import nn
from easydict import EasyDict
import tqdm
import shutil, csv
# Limit torch to a single intra-op CPU thread.
torch.set_num_threads(1)
# BCH code parameters for protecting the embedded secret; presumably the
# primitive polynomial and the number of correctable bit errors (t) in the
# pre-1.0 bchlib (polynomial, t) argument order — confirm against the
# installed bchlib version.
BCH_POLYNOMIAL = 37
BCH_BITS = 5
def get_bit_acc(secret_true, secret_pred):
    """Return the fraction of rounded predicted bits that equal the true bits.

    Args:
        secret_true: tensor of ground-truth bits (0/1), on any device.
        secret_pred: tensor of predicted bit values, expected to have the same
            shape as *secret_true*; every entry is rounded to the nearest
            integer before comparison.

    Returns:
        The number of matching elements divided by ``secret_pred.numel()``,
        as a NumPy float.
    """
    # ``.cpu()`` is a no-op for tensors already on the CPU, so moving both
    # inputs unconditionally also covers mixed placements (e.g. predictions on
    # the CPU with ground truth on CUDA), which previously raised a device error.
    secret_pred = secret_pred.cpu()
    secret_true = secret_true.cpu()
    # One total sum gives the same count as summing per row (dim=1) and then
    # summing again, and it also accepts 1-D inputs.
    correct = torch.sum(torch.round(secret_pred) == secret_true)
    return correct.numpy() / secret_pred.numel()
def _build_secret_packet(secret_bits):
    """Return *secret_bits* followed by its BCH parity as a list of 0/1 ints.

    Args:
        secret_bits: string of '0'/'1' characters, e.g. '01010011'.

    Returns:
        The bits of ``data + ecc`` (most significant bit of each byte first),
        where ``data`` is ``utils.binlist2bytearray`` of the input bits and
        ``ecc`` is ``bch.encode(data)``.
    """
    # NOTE(review): positional (polynomial, bits) matches the pre-1.0 bchlib
    # constructor; bchlib >= 1.0 takes ``BCH(t, prim_poly=...)`` instead —
    # confirm the installed version before touching this call.
    bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
    data = utils.binlist2bytearray([int(bit) for bit in secret_bits])
    packet = data + bch.encode(data)
    return [int(bit) for byte in packet for bit in format(byte, '08b')]


def _load_image_tensor(path):
    """Read *path* with ``cv2.imread`` and return a (1, C, H, W) float tensor.

    Pixel values are scaled from [0, 255] to [0, 1]; channel order is what
    OpenCV returns (BGR for colour images). Returns None when OpenCV cannot
    decode the file, since ``cv2.imread`` returns None in that case.
    """
    img_cv = cv2.imread(path)
    if img_cv is None:
        return None
    img = torch.from_numpy(img_cv).float() / 255.
    return img.permute(2, 0, 1).unsqueeze(0).contiguous()


def main():
    """Embed a fixed BCH-protected secret into every file in ``images_without_trace``.

    Reads ``00_setting.yaml`` and ``checkpoints/encoder.pth`` under
    ``root_dir``, runs the encoder once on ``checkpoints/train_mean_face.png``
    to obtain a residual for the secret, then adds that same residual to each
    input image, clamps to [0, 1] and writes the result (same file name) into
    ``save_dir``. Requires a CUDA device.

    Raises:
        ValueError: if the BCH packet length differs from ``secret_size`` in
            the YAML settings.
        FileNotFoundError: if the mean-face image cannot be read.
        OSError: if OpenCV fails to write an output image.
    """
    en_checkpoint_path = '{}/checkpoints/encoder.pth'.format(root_dir)
    images_dir = '{}/images_without_trace'.format(root_dir)
    # NOTE(review): "strace" looks like a typo for "trace"; left unchanged in
    # case other scripts read from this exact directory name.
    save_dir = '{}/images_with_strace'.format(root_dir)
    yaml_path = '{}/00_setting.yaml'.format(root_dir)
    mean_face_path = '{}/checkpoints/train_mean_face.png'.format(root_dir)
    set_secret_bit = '01010011'

    with open(yaml_path, 'r') as f:
        yml_args = EasyDict(yaml.load(f, Loader=yaml.SafeLoader))

    # Validate before touching the GPU: the encoder is built for exactly
    # secret_size bits, and the packet is the payload bits plus BCH parity.
    packet_bits = _build_secret_packet(set_secret_bit)
    if len(packet_bits) != yml_args.secret_size:
        raise ValueError(
            'BCH packet has {} bits but secret_size is {}'.format(
                len(packet_bits), yml_args.secret_size))

    encoder = model.StegaStampEncoder(secret_size=yml_args.secret_size, args=yml_args).cuda()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(en_checkpoint_path)
    encoder.load_state_dict(checkpoint['state_dict'])
    # Inference only: puts any dropout / batch-norm layers into eval behaviour.
    encoder.eval()

    mean_face = _load_image_tensor(mean_face_path)
    if mean_face is None:
        raise FileNotFoundError('cannot read mean face image: {}'.format(mean_face_path))
    secret_full = torch.tensor([packet_bits], dtype=torch.float32).cuda()

    files_list = sorted(glob.glob(images_dir + '/*'))
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(save_dir, exist_ok=True)

    # No gradients are needed; this avoids building an autograd graph.
    with torch.no_grad():
        # The residual depends only on the secret and the mean face, so it is
        # computed once and added unchanged to every image.
        noise_o = encoder((secret_full, mean_face.cuda()))
        for filename in tqdm.tqdm(files_list):
            image = _load_image_tensor(filename)
            if image is None:
                # glob('*') also matches non-image files and subdirectories.
                tqdm.tqdm.write('skipping unreadable file: {}'.format(filename))
                continue
            hidden_img_t = torch.clamp(image.cuda() + noise_o, 0.0, 1.0)
            hidden_img = hidden_img_t[0].permute(1, 2, 0)
            rescaled = np.clip((hidden_img * 255.0).cpu().numpy(), 0, 255).astype(np.uint8)
            save_path = osp.join(save_dir, osp.basename(filename))
            if not cv2.imwrite(save_path, rescaled):
                raise OSError('failed to write {}'.format(save_path))


if __name__ == "__main__":
    main()