Skip to content

Commit c43e945

Browse files
authored
Add files via upload
1 parent f391345 commit c43e945

10 files changed

+1588
-0
lines changed

distorted/42_2 copy.png

2.21 MB
Loading

distorted/63_2 copy.png

2.76 MB
Loading

extractor.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class ResidualBlock(nn.Module):
7+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8+
super(ResidualBlock, self).__init__()
9+
10+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12+
self.relu = nn.ReLU(inplace=True)
13+
14+
num_groups = planes // 8
15+
16+
if norm_fn == 'group':
17+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19+
if not stride == 1:
20+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21+
22+
elif norm_fn == 'batch':
23+
self.norm1 = nn.BatchNorm2d(planes)
24+
self.norm2 = nn.BatchNorm2d(planes)
25+
if not stride == 1:
26+
self.norm3 = nn.BatchNorm2d(planes)
27+
28+
elif norm_fn == 'instance':
29+
self.norm1 = nn.InstanceNorm2d(planes)
30+
self.norm2 = nn.InstanceNorm2d(planes)
31+
if not stride == 1:
32+
self.norm3 = nn.InstanceNorm2d(planes)
33+
34+
elif norm_fn == 'none':
35+
self.norm1 = nn.Sequential()
36+
self.norm2 = nn.Sequential()
37+
if not stride == 1:
38+
self.norm3 = nn.Sequential()
39+
40+
if stride == 1:
41+
self.downsample = None
42+
43+
else:
44+
self.downsample = nn.Sequential(
45+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46+
47+
48+
def forward(self, x):
49+
y = x
50+
y = self.relu(self.norm1(self.conv1(y)))
51+
y = self.relu(self.norm2(self.conv2(y)))
52+
53+
if self.downsample is not None:
54+
x = self.downsample(x)
55+
56+
return self.relu(x+y)
57+
58+
59+
class BasicEncoder(nn.Module):
60+
def __init__(self, input_dim=128, output_dim=128, norm_fn='batch'):
61+
super(BasicEncoder, self).__init__()
62+
self.norm_fn = norm_fn
63+
64+
if self.norm_fn == 'group':
65+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
66+
67+
elif self.norm_fn == 'batch':
68+
self.norm1 = nn.BatchNorm2d(64)
69+
70+
elif self.norm_fn == 'instance':
71+
self.norm1 = nn.InstanceNorm2d(64)
72+
73+
elif self.norm_fn == 'none':
74+
self.norm1 = nn.Sequential()
75+
76+
self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)
77+
self.relu1 = nn.ReLU(inplace=True)
78+
79+
self.in_planes = 64
80+
self.layer1 = self._make_layer(64, stride=1)
81+
self.layer2 = self._make_layer(128, stride=2)
82+
self.layer3 = self._make_layer(192, stride=2)
83+
84+
# output convolution
85+
self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1)
86+
87+
for m in self.modules():
88+
if isinstance(m, nn.Conv2d):
89+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
90+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
91+
if m.weight is not None:
92+
nn.init.constant_(m.weight, 1)
93+
if m.bias is not None:
94+
nn.init.constant_(m.bias, 0)
95+
96+
def _make_layer(self, dim, stride=1):
97+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
98+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
99+
layers = (layer1, layer2)
100+
101+
self.in_planes = dim
102+
return nn.Sequential(*layers)
103+
104+
def forward(self, x):
105+
x = self.conv1(x)
106+
x = self.norm1(x)
107+
x = self.relu1(x)
108+
109+
x = self.layer1(x)
110+
x = self.layer2(x)
111+
x = self.layer3(x)
112+
113+
x = self.conv2(x)
114+
115+
return x

inference.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from model import DocGeoNet
2+
from seg import U2NETP
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
import skimage.io as io
8+
import numpy as np
9+
import cv2
10+
import os
11+
from PIL import Image
12+
import argparse
13+
import warnings
14+
warnings.filterwarnings('ignore')
15+
16+
17+
class Net(nn.Module):
18+
def __init__(self, opt):
19+
super(Net, self).__init__()
20+
self.msk = U2NETP(3, 1)
21+
self.DocTr = DocGeoNet()
22+
23+
def forward(self, x):
24+
msk, _1,_2,_3,_4,_5,_6 = self.msk(x)
25+
msk = (msk > 0.5).float()
26+
x = msk * x
27+
28+
_, _, bm = self.DocTr(x)
29+
bm = (2 * (bm / 255.) - 1) * 0.99
30+
31+
return bm
32+
33+
34+
def reload_seg_model(model, path=""):
35+
if not bool(path):
36+
return model
37+
else:
38+
model_dict = model.state_dict()
39+
pretrained_dict = torch.load(path, map_location='cuda:0')
40+
print(len(pretrained_dict.keys()))
41+
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
42+
print(len(pretrained_dict.keys()))
43+
model_dict.update(pretrained_dict)
44+
model.load_state_dict(model_dict)
45+
46+
return model
47+
48+
49+
def reload_rec_model(model, path=""):
50+
if not bool(path):
51+
return model
52+
else:
53+
model_dict = model.state_dict()
54+
pretrained_dict = torch.load(path, map_location='cuda:0')
55+
print(len(pretrained_dict.keys()))
56+
pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
57+
print(len(pretrained_dict.keys()))
58+
model_dict.update(pretrained_dict)
59+
model.load_state_dict(model_dict)
60+
61+
return model
62+
63+
64+
def rec(seg_model_path, rec_model_path, distorrted_path, save_path, opt):
65+
print(torch.__version__)
66+
67+
# distorted images list
68+
img_list = sorted(os.listdir(distorrted_path))
69+
70+
# creat save path for rectified images
71+
if not os.path.exists(save_path):
72+
os.makedirs(save_path)
73+
74+
net = Net(opt).cuda()
75+
print(get_parameter_number(net))
76+
77+
# reload rec model
78+
reload_rec_model(net.DocTr, rec_model_path)
79+
reload_seg_model(net.msk, opt.seg_model_path)
80+
81+
net.eval()
82+
83+
for img_path in img_list:
84+
name = img_path.split('.')[-2] # image name
85+
img_path = distorrted_path + img_path # image path
86+
87+
im_ori = np.array(Image.open(img_path))[:, :, :3] / 255. # read image 0-255 to 0-1
88+
h, w, _ = im_ori.shape
89+
im = cv2.resize(im_ori, (256, 256))
90+
im = im.transpose(2, 0, 1)
91+
im = torch.from_numpy(im).float().unsqueeze(0)
92+
93+
with torch.no_grad():
94+
bm = net(im.cuda())
95+
bm = bm.cpu()
96+
97+
# save rectified image
98+
bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
99+
bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
100+
bm0 = cv2.blur(bm0, (3, 3))
101+
bm1 = cv2.blur(bm1, (3, 3))
102+
lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
103+
out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
104+
cv2.imwrite(save_path + name + '_rec' + '.png', ((out[0] * 255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8))
105+
106+
107+
def get_parameter_number(net):
108+
total_num = sum(p.numel() for p in net.parameters())
109+
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
110+
return {'Total': total_num, 'Trainable': trainable_num}
111+
112+
113+
def main():
114+
parser = argparse.ArgumentParser()
115+
parser.add_argument('--seg_model_path', default='./model_pretrained/preprocess.pth')
116+
parser.add_argument('--rec_model_path', default='./model_pretrained/DocGeoNet.pth')
117+
parser.add_argument('--distorrted_path', default='./distorted/')
118+
parser.add_argument('--save_path', default='./rec/')
119+
opt = parser.parse_args()
120+
121+
rec(seg_model_path=opt.seg_model_path,
122+
rec_model_path=opt.rec_model_path,
123+
distorrted_path=opt.distorrted_path,
124+
save_path=opt.save_path,
125+
opt=opt)
126+
127+
if __name__ == "__main__":
128+
main()

0 commit comments

Comments
 (0)