模型
变分自编码器,存在一个分布这个分布可以通过某种变化,变为训练集上的数据
变分意识就是变化分布
是一种全新的图像输入变为token的方式
耿贝尔分布
一个人的心跳一天统计10次,取最大值作为一天的心跳,因为每次心跳是随机值所
以每天心跳也是随机值每天的心跳服从的是耿贝尔分布在模型中argmax、采样都是
不可导的当模型中间有这类函数时,不能进行求导,导致无法更新参数。如果想要求
导可以使用重参数方法,其中耿贝尔分布就是重参数的一种方法。例如:当模型中
输出离散值时,比如码本中的位置,这种信息还需要输入到后续模型,此时这个位置
是无法求导更新参数的。这时需要重参数。在分类任务中虽然输出是softmax也是
离散值,但是这里没有argmax或者采样过程。
https://www.cnblogs.com/initial-h/p/9468974.html
耿贝尔分布使用
耿贝尔噪声如下
对于网络输出的一个𝑛维向量𝑣 ,生成𝑛个服从均匀分布𝑈(0,1)的独立样本𝜖1,…,𝜖𝑛
通过𝐺𝑖=−log(−log(𝜖𝑖)) 计算得到𝐺𝑖
对应相加得到新的值向量𝑣′=[𝑣1+𝐺1,𝑣2+𝐺2,…,𝑣𝑛+𝐺𝑛]
通过softmax函数计算𝑣′向量结果
模型结构
整体上有三块
1 是下采样,通过一些列卷积对图像下采样
2 codebook 这块主要是通过建设一个图像token
3 codebook 生成的图进行重构
下采样是
‘’’
nn.Sequential(nn.Conv2d(3, 256, 4, stride = 2, padding = 1), nn.ReLU())
nn.Sequential(nn.Conv2d(256, 256, 4, stride = 2, padding = 1), nn.ReLU())
nn.Sequential(nn.Resnet(256, 256, 3, stride = 1, padding = 1), nn.ReLU(), nn.Resnet(256, 256, 3, stride = 1, padding = 1)
nn.ReLU(), nn.Resnet(256, 256, 1, stride = 1)
nn.Sequential(nn.Conv2d(256, 8192, 1, stride = 1))
’’’
码表建设 ‘’’ 码表类似文本任务中的文本embedding nn.Embedding(8192, 512) 图像经过下采样后得到B*8129*H*W 数据 logits 是图像下采样的输出,其中gumbel_softmax 是耿贝尔分布,再经过码表计算得到采样数据 soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through) sampled = einsum(‘b n h w, n d -> b d h w’, soft_one_hot, self.codebook.weight) ‘’’
上采样
没啥好说的,得到图像的位置token,经过码表映射成Embedding向量
然后向量经过坐标变换,然后下采样成图