CVAE (Conditional Variational Autoencoder)

Published: 2023/12/20 · 豆豆

notations

  • $x$: image
  • $z$: latent variable
  • $y$: label (omitted to lighten notation)
  • $p(x|z)$: decoder, implemented by `Decoder`
  • $q(z|x)$: encoder, implemented by `Encoder`
  • $\hat{p}(z)$: prior encoder (by variational inference), implemented by `PriorEncoder`

model structure

    import logging
    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torch.nn as nn
    from tqdm import tqdm

    # Encoder, Decoder, PriorEncoder, device and time_str are defined
    # elsewhere in the project and are not shown in this post.


    class CVAE(nn.Module):
        def __init__(self, config):
            super(CVAE, self).__init__()
            self.encoder = Encoder(...)
            self.decoder = Decoder(...)
            self.priorEncoder = PriorEncoder(...)

        def forward(self, x, y):
            x = x.reshape((-1, 784))  # MNIST
            mu, sigma = self.encoder(x, y)
            prior_mu, prior_sigma = self.priorEncoder(y)
            # reparameterization: z = mu + sigma * eps, eps ~ N(0, I)
            z = torch.randn_like(mu)
            z = z * sigma + mu
            reconstructed_x = self.decoder(z, y)
            reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
            return reconstructed_x, mu, sigma, prior_mu, prior_sigma

        def infer(self, y):
            # at inference time z is sampled from the learned prior
            prior_mu, prior_sigma = self.priorEncoder(y)
            z = torch.randn_like(prior_mu)
            z = z * prior_sigma + prior_mu
            reconstructed_x = self.decoder(z, y)
            return reconstructed_x


    class Loss(nn.Module):
        def __init__(self):
            super(Loss, self).__init__()
            self.loss_fn = nn.MSELoss(reduction='mean')
            self.kld_loss_weight = 1e-5

        def forward(self, x, reconstructed_x, mu, sigma, prior_mu, prior_sigma):
            # match shapes so that MSELoss cannot silently broadcast
            x = x.reshape(reconstructed_x.shape)
            mse_loss = self.loss_fn(x, reconstructed_x)
            # KL( N(mu, sigma^2) || N(prior_mu, prior_sigma^2) ), element-wise
            kld_loss = torch.log(prior_sigma / sigma) + (sigma**2 + (mu - prior_mu)**2) / (2 * prior_sigma**2) - 0.5
            kld_loss = torch.sum(kld_loss) / x.shape[0]
            loss = mse_loss + self.kld_loss_weight * kld_loss
            return loss


    def train(model, criterion, optimizer, data_loader, config):
        train_task_time_str = time_str()
        for epoch in range(config.num_epoch):
            loss_seq = []
            for step, (x, y) in tqdm(enumerate(data_loader)):
                # -------------------- data --------------------
                x = x.to(device)
                y = y.to(device)
                # -------------------- forward --------------------
                reconstructed_x, mu, sigma, prior_mu, prior_sigma = model(x, y)
                loss = criterion(x, reconstructed_x, mu, sigma, prior_mu, prior_sigma)
                # -------------------- log --------------------
                loss_seq.append(loss.item())
                # -------------------- backward --------------------
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # -------------------- end --------------------
            # mean of the last `batch_size` step losses
            recent_loss = sum(loss_seq[-config.batch_size:]) / config.batch_size
            logging.info(f'epoch {epoch:^5d} loss {recent_loss:.5f}')
            with torch.no_grad():
                # -------------------- file --------------------
                path = f'{config.save_fig_path}/{train_task_time_str}'  # type(model).__name__
                if not os.path.exists(path):
                    os.makedirs(path)
                path += f'/epoch{epoch:04d}.png'
                # -------------------- figure --------------------
                plt.close()
                fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(10, 2), dpi=512)
                fig.suptitle(f'epoch {epoch} loss {recent_loss:.5f}')
                # -------------------- infer --------------------
                # one sample per class, conditioned on a one-hot label
                y = torch.Tensor(list(range(config.num_class)))
                y = y.to(dtype=torch.int64)
                y = nn.functional.one_hot(y, num_classes=config.num_class)
                y = y.to(dtype=torch.float)
                y = y.to(device)
                x = model.infer(y)
                x = x.cpu()
                x = x.numpy()
                # min-max normalize to [0, 255]
                x -= x.min()
                x /= x.max()
                x *= 255
                x = x.astype(np.uint8)
                # -------------------- plot --------------------
                for idx, ax, arr in zip(range(config.num_class), axs, x):
                    ax.set_title(str(idx))
                    ax.axis('off')
                    ax.imshow(arr.reshape((28, 28)), cmap='BuGn')
                # -------------------- save --------------------
                # plt.show()
                plt.savefig(path)
                # -------------------- end --------------------
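The `kld_loss` expression in `Loss.forward` has the form of the closed-form KL divergence between two univariate Gaussians, $\mathrm{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(\hat\mu, \hat\sigma^2)) = \log\frac{\hat\sigma}{\sigma} + \frac{\sigma^2 + (\mu - \hat\mu)^2}{2\hat\sigma^2} - \frac{1}{2}$, applied element-wise. As a sanity check, the sketch below re-implements that expression in plain Python and compares it with a numerical integral of $q(z)\log\frac{q(z)}{p(z)}$. The function names `gaussian_kl`, `normal_logpdf`, and `numeric_kl` are illustrative and not part of the original code, and the integration range and step count are arbitrary assumptions.

```python
import math


def gaussian_kl(mu, sigma, prior_mu, prior_sigma):
    """Closed-form KL( N(mu, sigma^2) || N(prior_mu, prior_sigma^2) ) for scalars.

    Same expression as `kld_loss` in `Loss.forward`.
    """
    return (
        math.log(prior_sigma / sigma)
        + (sigma ** 2 + (mu - prior_mu) ** 2) / (2 * prior_sigma ** 2)
        - 0.5
    )


def normal_logpdf(z, mu, sigma):
    """Log-density of N(mu, sigma^2) at z."""
    return -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


def numeric_kl(mu, sigma, prior_mu, prior_sigma, lo=-20.0, hi=20.0, steps=40_000):
    """Midpoint-rule estimate of the integral of q(z) * (log q(z) - log p(z))."""
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = normal_logpdf(z, mu, sigma)
        log_p = normal_logpdf(z, prior_mu, prior_sigma)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


closed_form = gaussian_kl(0.5, 0.8, 0.0, 1.0)
numerical = numeric_kl(0.5, 0.8, 0.0, 1.0)
print(f"closed form: {closed_form:.6f}, numerical: {numerical:.6f}")
```

The two values should agree to several decimal places, and the divergence is zero when the posterior and the prior coincide, which is why this term pulls the encoder output towards the output of `PriorEncoder`.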

dynamics

  • The SGVB (stochastic gradient variational Bayes) framework follows the principle of the EM algorithm and uses variational inference to optimize the ELBO.
  • $\log p(x) = \mathrm{ELBO} \left( q(z|x), p(x|z) \right) + \mathrm{KL} \left( q(z|x) \| p(z|x) \right)$. Because the KL term is non-negative, the ELBO is a lower bound on, and a surrogate for, the log-likelihood.
  • $\mathrm{ELBO} = \mathop{\mathbb{E}} \limits_{q(z|x)} \left[ \log p(x|z)p(z) \right] + \mathrm{Entropy} \left( q(z|x) \right) = \mathop{\mathbb{E}} \limits_{q(z|x)} \left[ \log p(x|z) \right] - \mathrm{KL} \left( q(z|x) \| p(z) \right)$
    • The first form, $\mathop{\mathbb{E}} \limits_{q(z|x)} \left[ \log p(x|z)p(z) \right] + \mathrm{Entropy} \left( q(z|x) \right)$, is used to prove the principle of the EM algorithm.
    • The second form, $\mathop{\mathbb{E}} \limits_{q(z|x)} \left[ \log p(x|z) \right] - \mathrm{KL} \left( q(z|x) \| p(z) \right)$, is used for neural network optimization.
      • $\max \mathop{\mathbb{E}} \limits_{q(z|x)} \left[ \log p(x|z) \right] \stackrel{\textsf{sampling}}{\approx} \max \log p(x_i|z_i) \stackrel{\textsf{opposite}}{=} \min \mathtt{cross\_entropy\_loss} \stackrel{\textsf{substitution}}{\sim} \min \mathtt{mse\_loss}$, where $z_i \sim q(z|x_i)$. The expectation is estimated from a sample, the sign is flipped to obtain a loss, and the cross-entropy is replaced by the MSE.
      • $\min \mathrm{KL} \left( q(z|x) \| p(z) \right) \stackrel{\textsf{Variational Inference}}{\approx} \min \mathrm{KL} \left( q(z|x) \| \hat{p}(z) \right)$, where the prior $p(z)$ is approximated by the learned prior $\hat{p}(z)$.
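The decomposition $\log p(x) = \mathrm{ELBO} + \mathrm{KL}(q(z|x) \| p(z|x))$ and the equality of the two ELBO forms can be checked numerically on a toy model with a discrete latent variable, where every quantity can be computed by enumeration. The sketch below is only an illustration: the two-state prior, the likelihood values, and the approximate posterior `q` are made-up numbers, not taken from the CVAE above.

```python
import math

# Hypothetical toy model: z in {0, 1}, one fixed observation x.
prior = [0.3, 0.7]        # p(z)
likelihood = [0.9, 0.2]   # p(x | z) for the observed x
q = [0.6, 0.4]            # an arbitrary approximate posterior q(z | x)

# Exact quantities by enumeration.
joint = [pz * px for pz, px in zip(prior, likelihood)]   # p(x, z)
evidence = sum(joint)                                     # p(x)
posterior = [j / evidence for j in joint]                 # p(z | x)


def kl(a, b):
    """KL(a || b) for two discrete distributions given as lists."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


# Form 1: E_q[log p(x|z) p(z)] + Entropy(q)
entropy_q = -sum(qi * math.log(qi) for qi in q)
elbo_form1 = sum(qi * math.log(j) for qi, j in zip(q, joint)) + entropy_q

# Form 2: E_q[log p(x|z)] - KL(q || p(z))
elbo_form2 = sum(qi * math.log(px) for qi, px in zip(q, likelihood)) - kl(q, prior)

# Gap between the log-likelihood and the ELBO: KL(q(z|x) || p(z|x)) >= 0
gap = kl(q, posterior)

print(f"log p(x)   = {math.log(evidence):.6f}")
print(f"ELBO (1)   = {elbo_form1:.6f}")
print(f"ELBO (2)   = {elbo_form2:.6f}")
print(f"ELBO + gap = {elbo_form1 + gap:.6f}")
```

Setting `q` equal to `posterior` makes the gap vanish, which is the E-step of EM; in the neural network setting the encoder can only approximate that posterior, so the gap generally stays positive.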

kld_loss_weight = 1e-5

{'batch_size': 25,
 'conv_encoder': True,
 'learning_rate': 1e-05,
 'num_class': 10,
 'num_epoch': 16,
 'save_fig_path': './figs',
 'use_cuda': True}
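The training code reads hyperparameters as attributes (`config.num_epoch`, `config.batch_size`), while the values above are printed as a dictionary. One way to bridge the two, assuming the original `config` object is not available, is to wrap the dictionary in `types.SimpleNamespace`. This is a sketch rather than the original project's configuration loader, and the variable name `hyperparams` is illustrative.

```python
from types import SimpleNamespace

# Values copied from the hyperparameter listing above.
hyperparams = {
    'batch_size': 25,
    'conv_encoder': True,
    'learning_rate': 1e-05,
    'num_class': 10,
    'num_epoch': 16,
    'save_fig_path': './figs',
    'use_cuda': True,
}

# Attribute-style access, as used by `train` (e.g. config.num_epoch).
config = SimpleNamespace(**hyperparams)

print(config.batch_size, config.learning_rate, config.num_epoch)
```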

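This set of hyperparameters converges fairly quickly to good model parameters.
In experiments, a larger batch_size converged to worse model parameters, and a smaller learning_rate converged very slowly.

  • The network learns the extent of each digit before its shape. In epochs 0-3 the digits contain many noisy pixels; in epochs 4-15 they appear as smooth shapes.
  • The network learns the foreground (the digit) before the background (white). In epochs 0-10 the background is dark; in epochs 11-15 it is bright.
  • From epoch 11 the network starts learning the least important detail (the white background), and from epoch 12 overfitting gradually sets in. This is most visible for the digit 0, which looks like an 8 by epoch 15.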
