2
0

segmentation.py 816 B

12345678910111213141516171819202122
  1. import torch.nn as nn
  2. import torch.nn.functional as F
  class BCELoss(nn.Module):
      """Binary cross-entropy computed directly on logits.

      Returns a ``(loss, log_dict)`` pair; the log dict is always empty,
      presumably so this criterion can be swapped in where a
      ``(loss, logs)``-returning loss such as ``BCELossWithQuant`` is used.
      """

      def forward(self, prediction, target):
          """Compute BCE-with-logits between ``prediction`` and ``target``.

          Args:
              prediction: raw (pre-sigmoid) logits; the sigmoid is applied
                  inside ``F.binary_cross_entropy_with_logits``.
              target: tensor of targets; presumably same shape as
                  ``prediction`` with values in [0, 1] — TODO confirm.

          Returns:
              Tuple of the loss (mean-reduced, PyTorch's default reduction)
              and an empty dict of logging values.
          """
          no_logs = {}
          return F.binary_cross_entropy_with_logits(prediction, target), no_logs
  class BCELossWithQuant(nn.Module):
      """BCE-with-logits loss plus a weighted quantization loss.

      The total loss is ``BCE(prediction, target) + codebook_weight * qloss``.
      Detached copies of the total, BCE and quantization terms are returned
      alongside it for logging.
      """

      def __init__(self, codebook_weight=1.):
          """
          Args:
              codebook_weight: scale applied to ``qloss`` before it is added
                  to the BCE term (default ``1.``).
          """
          super().__init__()
          self.codebook_weight = codebook_weight

      def forward(self, qloss, target, prediction, split):
          """Combine the BCE and quantization losses and build a log dict.

          Args:
              qloss: quantization loss; must be a tensor because ``.detach()``
                  is called on it. Presumably produced by a vector-quantizer
                  codebook — TODO confirm against the caller.
              target: targets for the BCE term; presumably same shape as
                  ``prediction`` with values in [0, 1] — TODO confirm.
              prediction: raw (pre-sigmoid) logits.
              split: label used as the key prefix in the log dict
                  (e.g. ``"train"`` -> ``"train/total_loss"``).

          Returns:
              Tuple ``(loss, log)`` where ``loss`` is the differentiable total
              loss and ``log`` maps ``"<split>/total_loss"``,
              ``"<split>/bce_loss"`` and ``"<split>/quant_loss"`` to detached,
              mean-reduced tensors.
          """
          bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
          loss = bce_loss + self.codebook_weight * qloss
          # detach() cuts the graph and mean() allocates a fresh tensor, so the
          # logged values never alias `loss`; no extra clone() is required.
          log = {
              "{}/total_loss".format(split): loss.detach().mean(),
              "{}/bce_loss".format(split): bce_loss.detach().mean(),
              "{}/quant_loss".format(split): qloss.detach().mean(),
          }
          return loss, log