"""
PyTorch implementation of "Rethinking Semantic Segmentation from a
Sequence-to-Sequence Perspective with Transformers" (SETR).

As described in https://arxiv.org/abs/2012.15840

From the paper: we deploy a pure transformer (i.e., without convolution and
resolution reduction) to encode an image as a sequence of patches. With the
global context modeled in every layer of the transformer, this encoder can be
combined with a simple decoder to provide a powerful segmentation model,
termed SEgmentation TRansformer (SETR).
"""
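# Shape arithmetic for the default SETR configuration below (illustrative only):
# a 480x480 input split into 16x16 patches gives (480 / 16)^2 = 30 * 30 = 900
# tokens, each projected to embed_dim (1024); the decoder's per-patch logits
# are then bilinearly upsampled 16x back to the 480x480 input resolution.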

import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Feed-forward block used inside each transformer layer."""

    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to embed_dim."""

    def __init__(self, image_size=224, in_channel=3, patch_size=16, embed_dim=768):
        super().__init__()
        grid_size = image_size // patch_size
        self.num_patches = grid_size ** 2
        # A conv with kernel_size == stride == patch_size is equivalent to a
        # per-patch linear projection.
        self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                   # (B, embed_dim, H/P, W/P)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)
        return x, (H, W)

class Attention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual connection."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, mlp_ratio=4,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class TransformerEncoder(nn.Module):
    """ViT-style encoder: patch embedding + learned position embedding + transformer blocks."""

    def __init__(self, image_size=224, in_channel=3, patch_size=16, embed_dim=768,
                 num_heads=8, mlp_ratio=4, qkv_bias=False, depths=24, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding(image_size, in_channel,
                                              patch_size, embed_dim)
        num_patches = self.patch_embedding.num_patches
        self.position_embedding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks = nn.Sequential(*[Block(embed_dim, num_heads, qkv_bias,
                                            mlp_ratio, attn_drop, proj_drop) for _ in range(depths)])

    def forward(self, x):
        B = x.shape[0]
        x, (H, W) = self.patch_embedding(x)   # (B, N, C) with N = H * W
        x = x + self.position_embedding
        x = self.blocks(x)
        # (B, N, C) -> (B, C, H, W): move channels ahead of the token dimension
        # before reshaping, otherwise patch positions would be scrambled.
        x = x.transpose(1, 2).reshape(B, -1, H, W)
        return x

class SETR(nn.Module):
    """SETR with a naive decoder: two 1x1 convolutions followed by bilinear upsampling."""

    def __init__(self, image_size=480, in_channel=3, patch_size=16, embed_dim=1024,
                 num_heads=8, mlp_ratio=4, qkv_bias=False, depths=12, attn_drop=0.,
                 proj_drop=0., num_classes=19):
        super().__init__()
        self.encoder = TransformerEncoder(image_size, in_channel, patch_size, embed_dim,
                                          num_heads, mlp_ratio, qkv_bias, depths, attn_drop, proj_drop)
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, num_classes, 1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(),
            nn.Conv2d(num_classes, num_classes, 1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.encoder(x)    # (B, embed_dim, H/P, W/P)
        x = self.decoder(x)    # (B, num_classes, H/P, W/P)
        # Upsample the per-patch class logits back to the input resolution.
        x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=True)
        return x

if __name__ == "__main__":
    # Quick shape check.
    x = torch.randn(2, 3, 480, 480)
    model = SETR(num_classes=19)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 19, 480, 480])
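
    # A minimal, hypothetical training-step sketch (not the paper's training
    # recipe): random integer targets stand in for ground-truth segmentation
    # masks, and a standard per-pixel cross-entropy loss is used.
    target = torch.randint(0, 19, (2, 480, 480))   # fake labels, classes 0..18
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    optimizer.zero_grad()
    loss = criterion(y, target)
    loss.backward()
    optimizer.step()
    print(loss.item())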