
Commit 169f229

add setr
1 parent 30e0c84 commit 169f229

File tree

1 file changed: +133 -0 lines changed


vision_transformers/setr.py

@@ -0,0 +1,133 @@
"""
PyTorch implementation of "Rethinking Semantic Segmentation from a
Sequence-to-Sequence Perspective with Transformers".

As described in https://arxiv.org/abs/2012.15840

We deploy a pure transformer (i.e., without convolution and
resolution reduction) to encode an image as a sequence of
patches. With the global context modeled in every layer of
the transformer, this encoder can be combined with a simple
decoder to provide a powerful segmentation model, termed
SEgmentation TRansformer (SETR).
"""

import torch
from torch import nn
import torch.nn.functional as F
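
# Rough shape walkthrough (added for orientation, derived from the defaults used
# by the demo at the bottom of this file: 480x480 input, 16x16 patches,
# embed_dim=1024, 19 classes):
#   image  (B, 3, 480, 480)
#   -> patch embedding: (480 / 16) ** 2 = 900 tokens of width 1024 -> (B, 900, 1024)
#   -> transformer blocks keep the token shape -> (B, 900, 1024)
#   -> fold tokens back onto the 30x30 patch grid -> (B, 1024, 30, 30)
#   -> 1x1-conv decoder -> (B, 19, 30, 30)
#   -> bilinear upsample back to the input size -> (B, 19, 480, 480)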


class MLP(nn.Module):
    """Two-layer feed-forward network used inside each transformer block."""

    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class PatchEmbedding(nn.Module):
    """Split the image into non-overlapping patches and linearly project each one.

    The strided convolution below is equivalent to cutting the image into
    patch_size x patch_size tiles and applying a shared linear layer to each tile.
    """

    def __init__(self, image_size=224, in_channel=3, patch_size=16, embed_dim=768):
        super().__init__()
        grid_size = image_size // patch_size
        self.num_patches = grid_size ** 2
        self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H * W, C): one token per patch
        x = x.flatten(2).transpose(1, 2)
        return x, (H, W)
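
# Illustrative check (not part of the original commit): with the defaults above,
# a 224x224 image yields (224 / 16) ** 2 = 196 tokens of width 768.
#
#   tokens, (h, w) = PatchEmbedding()(torch.randn(2, 3, 224, 224))
#   # tokens.shape == torch.Size([2, 196, 768]); (h, w) == (14, 14)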


class Attention(nn.Module):
    """Multi-head self-attention: softmax(q @ k^T / sqrt(head_dim)) @ v per head."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0, proj_drop=0):
        super().__init__()
        assert dim % num_heads == 0
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v and split into heads:
        # (B, N, C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back: (B, num_heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
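
# Aside (not in the original commit): on PyTorch 2.x the matmul / softmax /
# dropout sequence above could likely be replaced by
# F.scaled_dot_product_attention(q, k, v, dropout_p=...), which computes the
# same scaled dot-product attention with a fused kernel.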


class Block(nn.Module):
    """Pre-norm transformer block: x = x + Attn(LN(x)); x = x + MLP(LN(x))."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, mlp_ratio=4,
                 attn_drop=0, proj_drop=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim * mlp_ratio)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class TransformerEncoder(nn.Module):
    """SETR encoder: patch embedding + learned position embedding + transformer blocks.

    The output tokens are folded back onto the patch grid, so the encoder maps
    (B, in_channel, H, W) to (B, embed_dim, H // patch_size, W // patch_size).
    """

    def __init__(self, image_size=224, in_channel=3, patch_size=16, embed_dim=768,
                 num_heads=8, mlp_ratio=4, qkv_bias=False, depths=24, attn_drop=0, proj_drop=0):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding(image_size, in_channel,
                                              patch_size, embed_dim)
        num_patches = self.patch_embedding.num_patches
        self.position_embedding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks = nn.Sequential(*[Block(embed_dim, num_heads, qkv_bias,
                                            mlp_ratio, attn_drop, proj_drop) for _ in range(depths)])

    def forward(self, x):
        B = x.shape[0]
        x, (H, W) = self.patch_embedding(x)
        x = x + self.position_embedding
        x = self.blocks(x)
        # (B, N, C) -> (B, C, H, W); transpose before reshaping so channels are
        # not mixed up with the spatial positions
        x = x.transpose(1, 2).reshape(B, -1, H, W)
        return x
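
# Illustrative check (not part of the original commit): with the defaults above,
# the encoder maps (1, 3, 224, 224) -> (1, 768, 14, 14).
#
#   enc = TransformerEncoder(depths=2)  # shallow depth just to keep the check cheap
#   # enc(torch.randn(1, 3, 224, 224)).shape == torch.Size([1, 768, 14, 14])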


class SETR(nn.Module):
    """SETR with the simple (naive) decoder head: two 1x1 convolutions followed
    by bilinear upsampling back to the input resolution."""

    def __init__(self, image_size=480, in_channel=3, patch_size=16, embed_dim=1024,
                 num_heads=8, mlp_ratio=4, qkv_bias=False, depths=12, attn_drop=0,
                 proj_drop=0, num_classes=19):
        super().__init__()
        self.encoder = TransformerEncoder(image_size, in_channel, patch_size, embed_dim,
                                          num_heads, mlp_ratio, qkv_bias, depths, attn_drop, proj_drop)
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, num_classes, 1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(),
            nn.Conv2d(num_classes, num_classes, 1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.encoder(x)
        x = self.decoder(x)
        # Upsample the per-patch logits back to the full image resolution
        x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=True)
        return x


if __name__ == "__main__":
    # Smoke test: a batch of two 480x480 images should produce per-pixel logits
    # for 19 classes, i.e. torch.Size([2, 19, 480, 480])
    x = torch.randn(2, 3, 480, 480)
    model = SETR(num_classes=19)
    y = model(x)
    print(y.shape)
