# ViT


class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    The input image is cut into non-overlapping ``patch_size`` patches. Each
    patch is flattened, layer-normalized and linearly projected to ``dim``.
    A learnable cls token is then prepended and a learnable 1-D position
    embedding is added. The sequence is run through ``Transformer``, pooled to
    one vector per image (cls token or token mean, per ``pool``) and mapped to
    ``num_classes`` logits.

    Args:
        image_size: image size, normalized to ``(height, width)`` by ``pair``.
        patch_size: patch size, normalized to ``(height, width)`` by ``pair``.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: forwarded to ``Transformer``.
        heads: forwarded to ``Transformer``.
        mlp_dim: forwarded to ``Transformer``.
        pool: ``'cls'`` to classify from the cls token, ``'mean'`` to average
            over all tokens.
        channels: number of input image channels.
        dim_head: forwarded to ``Transformer``.
        dropout: forwarded to ``Transformer``.
        emb_dropout: dropout probability applied to the embedded token
            sequence before the Transformer.

    Raises:
        ValueError: if the image height/width is not divisible by the patch
            height/width, or if ``pool`` is neither ``'cls'`` nor ``'mean'``.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # The rearrange pattern below requires an exact tiling; fail at
        # construction time instead of on the first forward pass.
        if image_height % patch_height or image_width % patch_width:
            raise ValueError('Image dimensions must be divisible by the patch size.')
        if pool not in {'cls', 'mean'}:
            raise ValueError("pool type must be either 'cls' (cls token) or 'mean' (mean pooling)")

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            # (b, c, h*p1, w*p2) -> (b, h*w, p1*p2*c): one flattened row per patch.
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One position per patch plus one for the prepended cls token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Return class logits for a batch of images.

        Args:
            img: image tensor laid out as ``(batch, channels, height, width)``,
                as required by the ``'b c (h p1) (w p2)'`` rearrange pattern.

        Returns:
            Logits of shape ``(batch, num_classes)``. This assumes
            ``Transformer`` preserves the ``(batch, tokens, dim)`` shape;
            TODO confirm against its definition.
        """
        # 1. Patchify the image into a text-like token sequence:
        #    (b, c, H, W) -> (b, h*w, p1*p2*c), then project to (b, h*w, dim).
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 2. Prepend the learnable cls token (one view per batch element), so
        #    the sequence has n + 1 tokens and matches the n + 1 rows of
        #    pos_embedding. Then add the learnable 1-D position embedding.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # 3. Run the Transformer to produce the final sequence representation.
        x = self.transformer(x)

        # 4. Pool to one vector per image, then map it to class logits.
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        return self.mlp_head(x)


# Author: houmin | Publish: January 1, 0001 | LastMod: March 2, 2026
# License: CC BY-NC-ND 4.0 | Linked Mentions: No backlinks found.