|
| 1 | +from model import DocGeoNet |
| 2 | +from seg import U2NETP |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | +import skimage.io as io |
| 8 | +import numpy as np |
| 9 | +import cv2 |
| 10 | +import os |
| 11 | +from PIL import Image |
| 12 | +import argparse |
| 13 | +import warnings |
| 14 | +warnings.filterwarnings('ignore') |
| 15 | + |
| 16 | + |
| 17 | +class Net(nn.Module): |
| 18 | + def __init__(self, opt): |
| 19 | + super(Net, self).__init__() |
| 20 | + self.msk = U2NETP(3, 1) |
| 21 | + self.DocTr = DocGeoNet() |
| 22 | + |
| 23 | + def forward(self, x): |
| 24 | + msk, _1,_2,_3,_4,_5,_6 = self.msk(x) |
| 25 | + msk = (msk > 0.5).float() |
| 26 | + x = msk * x |
| 27 | + |
| 28 | + _, _, bm = self.DocTr(x) |
| 29 | + bm = (2 * (bm / 255.) - 1) * 0.99 |
| 30 | + |
| 31 | + return bm |
| 32 | + |
| 33 | + |
| 34 | +def reload_seg_model(model, path=""): |
| 35 | + if not bool(path): |
| 36 | + return model |
| 37 | + else: |
| 38 | + model_dict = model.state_dict() |
| 39 | + pretrained_dict = torch.load(path, map_location='cuda:0') |
| 40 | + print(len(pretrained_dict.keys())) |
| 41 | + pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict} |
| 42 | + print(len(pretrained_dict.keys())) |
| 43 | + model_dict.update(pretrained_dict) |
| 44 | + model.load_state_dict(model_dict) |
| 45 | + |
| 46 | + return model |
| 47 | + |
| 48 | + |
| 49 | +def reload_rec_model(model, path=""): |
| 50 | + if not bool(path): |
| 51 | + return model |
| 52 | + else: |
| 53 | + model_dict = model.state_dict() |
| 54 | + pretrained_dict = torch.load(path, map_location='cuda:0') |
| 55 | + print(len(pretrained_dict.keys())) |
| 56 | + pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} |
| 57 | + print(len(pretrained_dict.keys())) |
| 58 | + model_dict.update(pretrained_dict) |
| 59 | + model.load_state_dict(model_dict) |
| 60 | + |
| 61 | + return model |
| 62 | + |
| 63 | + |
| 64 | +def rec(seg_model_path, rec_model_path, distorrted_path, save_path, opt): |
| 65 | + print(torch.__version__) |
| 66 | + |
| 67 | + # distorted images list |
| 68 | + img_list = sorted(os.listdir(distorrted_path)) |
| 69 | + |
| 70 | + # creat save path for rectified images |
| 71 | + if not os.path.exists(save_path): |
| 72 | + os.makedirs(save_path) |
| 73 | + |
| 74 | + net = Net(opt).cuda() |
| 75 | + print(get_parameter_number(net)) |
| 76 | + |
| 77 | + # reload rec model |
| 78 | + reload_rec_model(net.DocTr, rec_model_path) |
| 79 | + reload_seg_model(net.msk, opt.seg_model_path) |
| 80 | + |
| 81 | + net.eval() |
| 82 | + |
| 83 | + for img_path in img_list: |
| 84 | + name = img_path.split('.')[-2] # image name |
| 85 | + img_path = distorrted_path + img_path # image path |
| 86 | + |
| 87 | + im_ori = np.array(Image.open(img_path))[:, :, :3] / 255. # read image 0-255 to 0-1 |
| 88 | + h, w, _ = im_ori.shape |
| 89 | + im = cv2.resize(im_ori, (256, 256)) |
| 90 | + im = im.transpose(2, 0, 1) |
| 91 | + im = torch.from_numpy(im).float().unsqueeze(0) |
| 92 | + |
| 93 | + with torch.no_grad(): |
| 94 | + bm = net(im.cuda()) |
| 95 | + bm = bm.cpu() |
| 96 | + |
| 97 | + # save rectified image |
| 98 | + bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow |
| 99 | + bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow |
| 100 | + bm0 = cv2.blur(bm0, (3, 3)) |
| 101 | + bm1 = cv2.blur(bm1, (3, 3)) |
| 102 | + lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2 |
| 103 | + out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True) |
| 104 | + cv2.imwrite(save_path + name + '_rec' + '.png', ((out[0] * 255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)) |
| 105 | + |
| 106 | + |
| 107 | +def get_parameter_number(net): |
| 108 | + total_num = sum(p.numel() for p in net.parameters()) |
| 109 | + trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) |
| 110 | + return {'Total': total_num, 'Trainable': trainable_num} |
| 111 | + |
| 112 | + |
| 113 | +def main(): |
| 114 | + parser = argparse.ArgumentParser() |
| 115 | + parser.add_argument('--seg_model_path', default='./model_pretrained/preprocess.pth') |
| 116 | + parser.add_argument('--rec_model_path', default='./model_pretrained/DocGeoNet.pth') |
| 117 | + parser.add_argument('--distorrted_path', default='./distorted/') |
| 118 | + parser.add_argument('--save_path', default='./rec/') |
| 119 | + opt = parser.parse_args() |
| 120 | + |
| 121 | + rec(seg_model_path=opt.seg_model_path, |
| 122 | + rec_model_path=opt.rec_model_path, |
| 123 | + distorrted_path=opt.distorrted_path, |
| 124 | + save_path=opt.save_path, |
| 125 | + opt=opt) |
| 126 | + |
| 127 | +if __name__ == "__main__": |
| 128 | + main() |
0 commit comments