12345678910111213141516171819202122 |
- import torch.nn as nn
- import torch.nn.functional as F
class BCELoss(nn.Module):
    """Binary cross-entropy computed directly on logits.

    Thin wrapper around ``F.binary_cross_entropy_with_logits`` using its
    default (mean) reduction. The second element of the returned tuple is
    always an empty dict: this loss logs no extra terms.
    """

    def forward(self, prediction, target):
        """Return ``(loss, log)`` for raw-logit ``prediction`` against ``target``.

        Args:
            prediction: Raw logits (not probabilities).
            target: Target values for the BCE term.

        Returns:
            Tuple of the mean BCE loss tensor and an empty log dict.
        """
        extra_logs = dict()
        bce = F.binary_cross_entropy_with_logits(prediction, target)
        return bce, extra_logs
class BCELossWithQuant(nn.Module):
    """Binary cross-entropy on logits plus a weighted quantization loss.

    The total objective is ``bce + codebook_weight * qloss``.

    Args:
        codebook_weight: Multiplier applied to ``qloss`` before it is added
            to the BCE term.
    """

    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        """Compute the combined loss and a dict of detached values for logging.

        Args:
            qloss: Quantization loss tensor, scaled by ``codebook_weight``.
            target: Targets for the BCE term; ``F.binary_cross_entropy_with_logits``
                requires them to have the same size as ``prediction``.
            prediction: Raw logits (not probabilities).
            split: Value used as the key prefix of the log dict.

        Returns:
            Tuple ``(loss, log)``. ``loss`` keeps its autograd graph. ``log``
            maps ``"{split}/total_loss"``, ``"{split}/bce_loss"`` and
            ``"{split}/quant_loss"`` to detached, mean-reduced tensors.
        """
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight * qloss
        # detach() already drops the graph and mean() allocates a new tensor,
        # so an extra clone() before logging would only be a wasted copy.
        log = {
            f"{split}/total_loss": loss.detach().mean(),
            f"{split}/bce_loss": bce_loss.detach().mean(),
            f"{split}/quant_loss": qloss.detach().mean(),
        }
        return loss, log
|