changzy00/pytorch-attention

🦖 PyTorch implementation of popular Attention Mechanisms, Vision Transformers, MLP-Like models and CNNs. 🔥🔥🔥


This codebase is a PyTorch implementation of various attention mechanisms, CNNs, Vision Transformers and MLP-Like models.

If it is helpful for your work, please give it a ⭐

Updating...

Install

git clone https://github.com/changzy00/pytorch-attention.git
cd pytorch-attention
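
The repository is not packaged, so run your scripts from the repository root (or add it to PYTHONPATH) so that imports such as attention_mechanisms.*, vision_transformers.*, cnns.* and mlps.* resolve. A minimal smoke test, assuming only PyTorch is installed:

import torch
from attention_mechanisms.se_module import SELayer   # resolves when run from the repo root

x = torch.randn(2, 64, 32, 32)        # (batch, channels, height, width)
print(SELayer(64)(x).shape)           # attention modules keep the input shape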

Contents

Attention Mechanisms

1. Squeeze-and-Excitation Attention

  • Squeeze-and-Excitation Networks (CVPR 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.se_module import SELayer

x = torch.randn(2, 64, 32, 32)
attn = SELayer(64)
y = attn(x)
print(y.shape)

2. Convolutional Block Attention Module

  • CBAM: convolutional block attention module (ECCV 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.cbam import CBAM

x = torch.randn(2, 64, 32, 32)
attn = CBAM(64)
y = attn(x)
print(y.shape)

3. Bottleneck Attention Module

  • BAM: Bottleneck Attention Module (BMVC 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.bam import BAM

x = torch.randn(2, 64, 32, 32)
attn = BAM(64)
y = attn(x)
print(y.shape)

4. Double Attention

  • A2-Nets: Double Attention Networks (NeurIPS 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.double_attention import DoubleAttention

x = torch.randn(2, 64, 32, 32)
attn = DoubleAttention(64, 32, 32)
y = attn(x)
print(y.shape)

5. Style Attention

  • SRM: A Style-based Recalibration Module for Convolutional Neural Networks (ICCV 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.srm import SRM

x = torch.randn(2, 64, 32, 32)
attn = SRM(64)
y = attn(x)
print(y.shape)

6. Global Context Attention

  • GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gc_module import GCModule

x = torch.randn(2, 64, 32, 32)
attn = GCModule(64)
y = attn(x)
print(y.shape)

7. Selective Kernel Attention

  • Selective Kernel Networks (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.sk_module import SKLayer

x = torch.randn(2, 64, 32, 32)
attn = SKLayer(64)
y = attn(x)
print(y.shape)

8. Linear Context Attention

  • Linear Context Transform Block (AAAI 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.lct import LCT

x = torch.randn(2, 64, 32, 32)
attn = LCT(64, groups=8)
y = attn(x)
print(y.shape)

9. Gated Channel Attention

  • Gated Channel Transformation for Visual Recognition (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gate_channel_module import GCT

x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)

10. Efficient Channel Attention

  • ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.eca import ECALayer

x = torch.randn(2, 64, 32, 32)
attn = ECALayer(64)
y = attn(x)
print(y.shape)

11. Triplet Attention

  • Rotate to Attend: Convolutional Triplet Attention Module (WACV 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.triplet_attention import TripletAttention

x = torch.randn(2, 64, 32, 32)
attn = TripletAttention(64)
y = attn(x)
print(y.shape)

12. Gaussian Context Attention

  • Gaussian Context Transformer (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gct import GCT

x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)

13. Coordinate Attention

  • Coordinate Attention for Efficient Mobile Network Design (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.coordatten import CoordinateAttention

x = torch.randn(2, 64, 32, 32)
attn = CoordinateAttention(64, 64)
y = attn(x)
print(y.shape)

14. SimAM

  • SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021) pdf
  • Model Overview

  • Code
import torch
from attention_mechanisms.simam import simam_module

x = torch.randn(2, 64, 32, 32)
attn = simam_module(64)
y = attn(x)
print(y.shape)

15. Dual Attention

  • Dual Attention Network for Scene Segmentation (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.dual_attention import PAM, CAM

x = torch.randn(2, 64, 32, 32)
#attn = PAM(64)
attn = CAM()
y = attn(x)
print(y.shape)
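
All of the modules above are drop-in nn.Module blocks that preserve the spatial resolution, and most of them construct from the channel count (see the individual examples for the exact signatures), so they can be slotted into an existing CNN stage. A minimal sketch of this pattern; the ConvSEBlock wrapper below is illustrative and not part of this repo:

import torch
import torch.nn as nn
from attention_mechanisms.se_module import SELayer

class ConvSEBlock(nn.Module):
    """Toy conv block with channel attention and a residual connection."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.attn = SELayer(channels)  # swap in another shape-preserving module if desired

    def forward(self, x):
        return x + self.attn(self.conv(x))

x = torch.randn(2, 64, 32, 32)
print(ConvSEBlock(64)(x).shape)  # torch.Size([2, 64, 32, 32])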

Vision Transformers

1. ViT Model

  • An image is worth 16x16 words: Transformers for image recognition at scale (ICLR 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.ViT import VisionTransformer

x = torch.randn(2, 3, 224, 224)
model = VisionTransformer()
y = model(x)
print(y.shape) #[2, 1000]

2. XCiT Model

  • XCiT: Cross-Covariance Image Transformers (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.xcit import xcit_nano_12_p16
x = torch.randn(2, 3, 224, 224)
model = xcit_nano_12_p16()
y = model(x)
print(y.shape)

3. PiT Model

  • Rethinking Spatial Dimensions of Vision Transformers (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.pit import pit_ti
x = torch.randn(2, 3, 224, 224)
model = pit_ti()
y = model(x)
print(y.shape)

4. CvT Model

  • CvT: Introducing Convolutions to Vision Transformers (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cvt import cvt_13
x = torch.randn(2, 3, 224, 224)
model = cvt_13()
y = model(x)
print(y.shape)

5. PvT Model

  • Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.pvt import pvt_t
x = torch.randn(2, 3, 224, 224)
model = pvt_t()
y = model(x)
print(y.shape)

6. CMT Model

  • CMT: Convolutional Neural Networks Meet Vision Transformers (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cmt import cmt_ti
x = torch.randn(2, 3, 224, 224)
model = cmt_ti()
y = model(x)
print(y.shape)

7. PoolFormer Model

  • MetaFormer is Actually What You Need for Vision (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.poolformer import poolformer_12
x = torch.randn(2, 3, 224, 224)
model = poolformer_12()
y = model(x)
print(y.shape)

8. KVT Model

  • KVT: k-NN Attention for Boosting Vision Transformers (ECCV 2022) pdf

  • Code
import torch
from vision_transformers.kvt import KVT
x = torch.randn(2, 3, 224, 224)
model = KVT()
y = model(x)
print(y.shape)

9. MobileViT Model

  • MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer (ICLR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.mobilevit import mobilevit_s
x = torch.randn(2, 3, 224, 224)
model = mobilevit_s()
y = model(x)
print(y.shape)

10. P2T Model

  • Pyramid Pooling Transformer for Scene Understanding (TPAMI 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.p2t import p2t_tiny
x = torch.randn(2, 3, 224, 224)
model = p2t_tiny()
y = model(x)
print(y.shape)

11. EfficientFormer Model

  • EfficientFormer: Vision Transformers at MobileNet Speed (NeurIPS 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.efficientformer import efficientformer_l1
x = torch.randn(2, 3, 224, 224)
model = efficientformer_l1()
y = model(x)
print(y.shape)

12. ShiftViT Model

  • When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism (AAAI 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.shiftvit import shift_t
x = torch.randn(2, 3, 224, 224)
model = shift_t()
y = model(x)
print(y.shape)

13. CSWin Model

  • CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cswin import CSWin_64_12211_tiny_224
x = torch.randn(2, 3, 224, 224)
model = CSWin_64_12211_tiny_224()
y = model(x)
print(y.shape)

14. DilateFormer Model

  • DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition (TMM 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.dilateformer import dilateformer_tiny
x = torch.randn(2, 3, 224, 224)
model = dilateformer_tiny()
y = model(x)
print(y.shape)

15. BViT Model

  • BViT: Broad Attention based Vision Transformer (TNNLS 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.bvit import BViT_S
x = torch.randn(2, 3, 224, 224)
model = BViT_S()
y = model(x)
print(y.shape)

16. MOAT Model

  • MOAT: Alternating Mobile Convolution and Attention Brings Strong Vision Models (ICLR 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.moat import moat_0
x = torch.randn(2, 3, 224, 224)
model = moat_0()
y = model(x)
print(y.shape)

17. SegFormer Model

  • SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.moat import SegFormer
x = torch.randn(2, 3, 512, 512)
model = SegFormer(num_classes=50)
y = model(x)
print(y.shape)

18. SETR Model

  • Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.setr import SETR
x = torch.randn(2, 3, 480, 480)
model = SETR(num_classes=50)
y = model(x)
print(y.shape)
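
SegFormer and SETR above are segmentation models; the other transformers are image classifiers returning class logits (e.g. [2, 1000] in the ViT example). They are all plain nn.Modules, so the usual PyTorch inference pattern applies. A minimal sketch, reusing the ViT example:

import torch
from vision_transformers.ViT import VisionTransformer

model = VisionTransformer().eval()
params = sum(p.numel() for p in model.parameters())
print(f"parameters: {params / 1e6:.1f}M")

with torch.no_grad():                          # disable autograd for inference
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.argmax(dim=1))                    # predicted class index (1000 classes by default)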

Convolutional Neural Networks (CNNs)

1. NiN Model

  • Network In Network (ICLR 2014) pdf

  • Model Overview

  • Code
import torch
from cnns.NiN import NiN 
x = torch.randn(2, 3, 224, 224)
model = NiN()
y = model(x)
print(y.shape)

2. ResNet Model

  • Deep Residual Learning for Image Recognition (CVPR 2016) pdf

  • Model Overview

  • Code
import torch
from cnns.resnet import resnet18 
x = torch.randn(2, 3, 224, 224)
model = resnet18()
y = model(x)
print(y.shape)

3. WideResNet Model

  • Wide Residual Networks (BMVC 2016) pdf

  • Model Overview

  • Code
import torch
from cnns.wideresnet import wideresnet
x = torch.randn(2, 3, 224, 224)
model = wideresnet()
y = model(x)
print(y.shape)

4. DenseNet Model

  • Densely Connected Convolutional Networks (CVPR 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.densenet import densenet121
x = torch.randn(2, 3, 224, 224)
model = densenet121()
y = model(x)
print(y.shape)

5. PyramidNet Model

  • Deep Pyramidal Residual Networks (CVPR 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.pyramidnet import pyramidnet18
x = torch.randn(2, 3, 224, 224)
model = pyramidnet18()
y = model(x)
print(y.shape)

6. MobileNetV1 Model

  • MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (arXiv 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv1 import MobileNetv1
x = torch.randn(2, 3, 224, 224)
model = MobileNetv1()
y = model(x)
print(y.shape)

7. MobileNetV2 Model

  • MobileNetV2: Inverted Residuals and Linear Bottlenecks (CVPR 2018) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv2 import MobileNetv2
x = torch.randn(2, 3, 224, 224)
model = MobileNetv2()
y = model(x)
print(y.shape)

8. MobileNetV3 Model

  • Searching for MobileNetV3 (ICCV 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv3 import mobilenetv3_small
x = torch.randn(2, 3, 224, 224)
model = mobilenetv3_small()
y = model(x)
print(y.shape)

9. MnasNet Model

  • MnasNet: Platform-Aware Neural Architecture Search for Mobile (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.mnasnet import MnasNet
x = torch.randn(2, 3, 224, 224)
model = MnasNet()
y = model(x)
print(y.shape)

10. EfficientNetV1 Model

  • EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (ICML 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.efficientnet import EfficientNet
x = torch.randn(2, 3, 224, 224)
model = EfficientNet()
y = model(x)
print(y.shape)

11. Res2Net Model

  • Res2Net: A New Multi-scale Backbone Architecture (TPAMI 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.res2net import res2net50
x = torch.randn(2, 3, 224, 224)
model = res2net50()
y = model(x)
print(y.shape)

12. MobileNeXt Model

  • Rethinking Bottleneck Structure for Efficient Mobile Network Design (ECCV 2020) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenext import MobileNeXt
x = torch.randn(2, 3, 224, 224)
model = MobileNeXt()
y = model(x)
print(y.shape)

13. GhostNet Model

  • GhostNet: More Features from Cheap Operations (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from cnns.ghostnet import ghostnet
x = torch.randn(2, 3, 224, 224)
model = ghostnet()
y = model(x)
print(y.shape)

14. EfficientNetV2 Model

  • EfficientNetV2: Smaller Models and Faster Training (ICML 2021) pdf

  • Model Overview

  • Code
import torch
from cnns.efficientnet import EfficientNetV2
x = torch.randn(2, 3, 224, 224)
model = EfficientNetV2()
y = model(x)
print(y.shape)

15. ConvNeXt Model

  • A ConvNet for the 2020s (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from cnns.convnext import convnext_18
x = torch.randn(2, 3, 224, 224)
model = convnext_18()
y = model(x)
print(y.shape)

16. Unet Model

  • U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI 2015) pdf

  • Model Overview

  • Code
import torch
from cnns.unet import Unet
x = torch.randn(2, 3, 512, 512)
model = Unet(10)
y = model(x)
print(y.shape)

17. ESPNet Model

  • ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation (ECCV 2018) pdf

  • Model Overview

  • Code
import torch
from cnns.espnet import ESPNet
x = torch.randn(2, 3, 512, 512)
model = ESPNet(10)
y = model(x)
print(y.shape)
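
Unet and ESPNet are segmentation networks: they are constructed with a class count and return per-pixel class scores rather than image-level logits. A minimal sketch of turning the output into a label map, assuming the output layout is [batch, num_classes, H, W]:

import torch
from cnns.unet import Unet

model = Unet(10)                       # 10 segmentation classes, as in the example above
x = torch.randn(2, 3, 512, 512)
logits = model(x)                      # per-pixel class scores, layout assumed [2, 10, 512, 512]
pred = logits.argmax(dim=1)            # per-pixel class labels, [2, 512, 512]
print(pred.shape, pred.dtype)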

MLP-Like Models

1. MLP-Mixer Model

  • MLP-Mixer: An all-MLP Architecture for Vision (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.mlp_mixer import MLP_Mixer
x = torch.randn(2, 3, 224, 224)
model = MLP_Mixer()
y = model(x)
print(y.shape)

2. gMLP Model

  • Pay Attention to MLPs (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.gmlp import gMLP
x = torch.randn(2, 3, 224, 224)
model = gMLP()
y = model(x)
print(y.shape)

3. GFNet Model

  • Global Filter Networks for Image Classification (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.gfnet import GFNet
x = torch.randn(2, 3, 224, 224)
model = GFNet()
y = model(x)
print(y.shape)

4. sMLP Model

  • Sparse MLP for Image Recognition: Is Self-Attention Really Necessary? (AAAI 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.smlp import sMLPNet
x = torch.randn(2, 3, 224, 224)
model = sMLPNet()
y = model(x)
print(y.shape)

5. DynaMixer Model

  • DynaMixer: A Vision MLP Architecture with Dynamic Mixing (ICML 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.dynamixer import DynaMixer
x = torch.randn(2, 3, 224, 224)
model = DynaMixer()
y = model(x)
print(y.shape)

6. ConvMixer Model

  • Patches Are All You Need? (TMLR 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.convmixer import ConvMixer
x = torch.randn(2, 3, 224, 224)
model = ConvMixer(128, 6)
y = model(x)
print(y.shape)

7. ViP Model

  • Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (TPAMI 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.vip import vip_s7
x = torch.randn(2, 3, 224, 224)
model = vip_s7()
y = model(x)
print(y.shape)

8. CycleMLP Model

  • CycleMLP: A MLP-like Architecture for Dense Prediction (ICLR 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.cyclemlp import CycleMLP_B1
x = torch.randn(2, 3, 224, 224)
model = CycleMLP_B1()
y = model(x)
print(y.shape)

9. Sequencer Model

  • Sequencer: Deep LSTM for Image Classification (NeurIPS 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.sequencer import sequencer_s
x = torch.randn(2, 3, 224, 224)
model = sequencer_s()
y = model(x)
print(y.shape)

10. MobileViG Model

  • MobileViG: Graph-Based Sparse Attention for Mobile Vision Applications (CVPRW 2023) pdf

  • Model Overview

  • Code
import torch
from mlps.mobilevig import mobilevig_s
x = torch.randn(2, 3, 224, 224)
model = mobilevig_s()
y = model(x)
print(y.shape)
