论文阅读-BEITv2

Posted by 高庆东 on May 7, 2022

模型流程

1、训练图像编码器

常见的Image token 有三种方式

grid feature

取卷积网络输出的特征图,特征图上的每个位置就是一个 token

region feature

比较简单,即利用目标检测的结果,把检测框框出区域的特征作为 token

patch feature

直接把图片切成若干个 patch,然后对每个 patch 提取特征

VQ-VAE

注意

BEiT 系列模型在视觉任务中使用相对位置编码(而非绝对位置编码)

论文提出的观点和使用的方法

1、blockwise masking
2、相对位置编码
3、向量化知识蒸馏VQ-KD
4、EmbeddingEMA:用指数滑动平均(EMA)更新码本,避免只有少部分 Embedding 被使用
5、滑动

实现过程

1、找一个 CLIP 模型提取图像特征,需要保存除 cls token 以外的每个 token 的特征;
2、再用 CLIP 中将 768 维映射到 512 维的投影矩阵把每个 token 都变换一下,得到 VQ-KD 的教师信号(target);
3、ViT 提取的特征经过码本量化,即下面代码实现的编码过程;
4、ViT 解码器对量化结果解码,并与教师信号计算余弦距离损失,训练得到 VQ-KD 模型;
5、训练 ViT:ViT 提取图像特征,然后由 VQ-KD 提取码本特征,得到码本 index;
6、使用交叉熵损失(CE loss)训练 ViT。

```python

def l2norm(t):
    """Return ``t`` scaled to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place to ``decay * moving_avg + (1 - decay) * new``.

    Writes through ``.data`` so the update bypasses autograd and can be applied
    to parameters and buffers alike.
    """
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def sample_vectors(samples, num):
    """Return ``num`` randomly chosen rows of ``samples``.

    Rows are drawn without replacement when ``samples`` has at least ``num``
    rows, otherwise with replacement. ``samples`` must have at least one row.
    """
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        # Not enough rows for a permutation: sample indices with replacement.
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]

def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Cluster the rows of ``samples`` into ``num_clusters`` centroids (Lloyd's k-means).

    Args:
        samples: 2-D tensor of shape (num_samples, dim).
        num_clusters: number of centroids to compute.
        num_iters: number of assignment/update rounds.
        use_cosine_sim: if True, assign by dot product and L2-normalize the
            centroids after each update (presumably ``samples`` are already
            L2-normalized by the caller -- TODO confirm); otherwise assign by
            squared Euclidean distance.

    Returns:
        ``(means, bins)``: centroids of shape (num_clusters, dim) and the number
        of samples assigned to each centroid in the last iteration (all zeros
        when ``num_iters`` is 0).
    """
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    # Seed the centroids with randomly chosen samples.
    means = sample_vectors(samples, num_clusters)
    # Defined up front so num_iters == 0 does not leave ``bins`` unbound.
    bins = torch.zeros(num_clusters, dtype=torch.long, device=device)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            # Negative squared distance so that argmax picks the nearest centroid.
            # Materializes a (num_samples, num_clusters, dim) tensor.
            diffs = samples.unsqueeze(1) - means.unsqueeze(0)
            dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for empty clusters; their old centroid is kept below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum the samples of each cluster, then divide by the cluster size.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins

class EmbeddingEMA(nn.Module):
    """Codebook of ``num_tokens`` vectors of size ``codebook_dim`` that is not trained by gradients.

    The codebook (``weight``) and its running statistics (``cluster_size``,
    ``embed_avg``) are ``nn.Parameter`` objects with ``requires_grad=False``:
    they are part of the state dict but only change through the explicit
    update methods below or through in-place updates made by the caller.

    Args:
        num_tokens: number of codebook entries.
        codebook_dim: dimensionality of each entry.
        decay: EMA decay used by the ``*_ema_update`` methods.
        eps: additive smoothing used by ``weight_update``.
        kmeans_init: when no checkpoint is given, start from an all-zero
            codebook and defer initialization to ``init_embed_`` (k-means on
            the first batch it receives); otherwise start from random unit
            vectors.
        codebook_init_path: path of a saved codebook tensor; '' means none.
    """

    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5,
                 kmeans_init=True, codebook_init_path=''):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps
        if codebook_init_path == '':
            if not kmeans_init:
                weight = torch.randn(num_tokens, codebook_dim)
                weight = l2norm(weight)
            else:
                # Placeholder; real values are written by init_embed_.
                weight = torch.zeros(num_tokens, codebook_dim)
            # 1.0 once the codebook holds real values, 0.0 while k-means init is pending.
            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        else:
            print(f"load init codebook weight from {codebook_init_path}")
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            self.register_buffer('initted', torch.Tensor([True]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        # Read by NormEMAVectorQuantizer.forward; set to False to skip codebook EMA updates.
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook with cosine k-means on ``data`` unless already initialized.

        ``data`` is expected to be (num_samples, codebook_dim). Runs 10 k-means
        iterations and also stores the resulting cluster sizes.
        """
        if self.initted:
            return
        print("Performing Kemans init for codebook")
        embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        # NOTE(review): embed_avg is not refreshed here and keeps its __init__
        # value (zeros under k-means init); this matters only if
        # embed_avg_ema_update/weight_update are used afterwards.
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        """Look up the codebook vectors for the integer indices ``embed_id``."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """EMA-update the per-code usage counts in place."""
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        """EMA-update the per-code embedding accumulator in place."""
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        """Set ``weight`` to ``embed_avg`` divided by Laplace-smoothed cluster sizes.

        The smoothing ``(size + eps) / (n + num_tokens * eps) * n`` preserves the
        total count ``n`` (when ``num_tokens`` equals the codebook size) and
        keeps the divisor non-zero for unused codes as long as ``n > 0``.
        """
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)

def norm_ema_inplace(moving_avg, new, decay):
    """EMA-update ``moving_avg`` towards ``new`` in place, then L2-normalize it.

    Computes ``decay * moving_avg + (1 - decay) * new`` and rescales the result
    to unit norm along the last dimension, writing through ``.data`` so it can
    be applied directly to a non-trainable codebook parameter.
    """
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
    moving_avg.data.copy_(l2norm(moving_avg.data))

class NormEMAVectorQuantizer(nn.Module):
    """Vector quantizer over an L2-normalized codebook that is updated by EMA.

    Each spatial position of the input is L2-normalized and replaced by its
    nearest codebook entry. The result is returned through a straight-through
    estimator, together with a commitment loss and the selected code indices.
    In training mode, the codebook moves by EMA towards the normalized mean of
    the vectors assigned to each code, instead of being trained by gradients.

    Args:
        n_embed: number of codebook entries (8192 in the original notes).
        embedding_dim: dimensionality of each code (32 in the original notes).
        beta: weight of the commitment loss.
        decay: EMA decay for the codebook and the usage statistics.
        eps: smoothing constant forwarded to ``EmbeddingEMA``.
        statistic_code_usage: if True, keep an EMA of per-code usage counts in
            the ``cluster_size`` buffer.
        kmeans_init: forwarded to ``EmbeddingEMA``.
        codebook_init_path: forwarded to ``EmbeddingEMA``; '' means no checkpoint.
    """

    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps,
                                      kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        if distributed.is_available() and distributed.is_initialized():
            print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
            # In-place sum across processes, so every rank sees global counts.
            self.all_reduce_fn = distributed.all_reduce
        else:
            # Single process: calling nn.Identity just returns its argument (no-op).
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        """Zero the code-usage statistics and move them to ``device`` (if tracked)."""
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        """Quantize ``z``.

        Args:
            z: tensor laid out as (batch, channel, height, width), where
                channel equals ``embedding_dim``.

        Returns:
            ``(z_q, loss, encoding_indices)``: the quantized tensor in the input
            layout (gradients flow straight through to ``z``), the scalar
            commitment loss, and the flat (batch * height * width,) code indices.
        """
        # (b, c, h, w) -> (b, h, w, c): every spatial position becomes one query vector.
        z = l2norm(z.permute(0, 2, 3, 1))
        z_flattened = z.reshape(-1, self.codebook_dim)

        # Lazy k-means initialization on the first batch; a no-op once initialized.
        self.embedding.init_embed_(z_flattened)

        codebook = self.embedding.weight
        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e for every (query, code) pair.
        d = (z_flattened.pow(2).sum(dim=1, keepdim=True)
             + codebook.pow(2).sum(dim=1)
             - 2 * torch.einsum('bd,nd->bn', z_flattened, codebook))
        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # Statistics and codebook updates are bookkeeping, not part of the loss:
        # keep them out of autograd.
        with torch.no_grad():
            if not self.training:
                # Only track usage if the cluster_size buffer exists.
                if self.statistic_code_usage:
                    cluster_size = encodings.sum(0)
                    self.all_reduce_fn(cluster_size)
                    ema_inplace(self.cluster_size, cluster_size, self.decay)
            elif self.embedding.update:
                bins = encodings.sum(0)
                self.all_reduce_fn(bins)
                if self.statistic_code_usage:
                    ema_inplace(self.cluster_size, bins, self.decay)

                zero_mask = bins == 0
                bins = bins.masked_fill(zero_mask, 1.)

                # (dim, n) @ (n, num_tokens): per-code sum of the assigned vectors.
                embed_sum = z_flattened.t() @ encodings
                self.all_reduce_fn(embed_sum)

                embed_normalized = l2norm((embed_sum / bins.unsqueeze(0)).t())
                # Codes that received no vectors keep their current value.
                embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                               embed_normalized)
                norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # Commitment loss: pull encoder outputs towards their (fixed) assigned codes.
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # Straight-through estimator: forward uses z_q, backward copies gradients to z.
        z_q = z + (z_q - z).detach()

        # (b, h, w, c) -> (b, c, h, w) to match the input layout.
        z_q = z_q.permute(0, 3, 1, 2)
        return z_q, loss, encoding_indices

```

MAE对比
MAE 中掩码的比率非常高,达到 75%;相比之下,BERT 对文本数据的掩码率只有 15%。
这体现出图像数据的高度冗余性和文本数据的高度语义性。

训练细节

训练 tokenizer 时,由于中间的最近邻查表操作不可微,为了让梯度能够反传,
可将 decoder 输入处的梯度直接拷贝到 encoder 输出(直通估计)。因为 quantizer
查找的是每个编码器输出的最近邻 embedding,码本 embedding 的梯度可以为编码器
指示合理的优化方向。为了稳定码本的训练、提高码本利用率并避免码本坍塌(即只有一小部分
embedding 会被使用),tokenizer 的训练采用了一些 trick,
其中包括使用 L2 归一化后的距离、将 embedding 维度降低到 32 维,以及用指数滑动平均(EMA)更新码本。