import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")
        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return log_dict_ae

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
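
# Usage sketch for VQModel (illustrative only; `ddconfig` and `lossconfig`
# normally come from the project's YAML configs via instantiate_from_config,
# and the concrete values here are assumptions, not the published configs):
#
#   model = VQModel(ddconfig=ddconfig, lossconfig=lossconfig,
#                   n_embed=1024, embed_dim=256)
#   x = torch.randn(1, 3, 256, 256)          # NCHW image scaled to [-1, 1]
#   quant, emb_loss, info = model.encode(x)  # quantized latents at H/f x W/f
#   xrec = model.decode(quant)               # reconstruction, same shape as x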


class VQSegmentationModel(VQModel):
    def __init__(self, n_labels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        return opt_ae

    def training_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        total_loss = log_dict_ae["val/total_loss"]
        self.log("val/total_loss", total_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return aeloss

    @torch.no_grad()
    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            # convert logits to indices
            xrec = torch.argmax(xrec, dim=1, keepdim=True)
            xrec = F.one_hot(xrec, num_classes=x.shape[1])
            xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log
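
# Shape walkthrough for the colorization above (derived from the code, with
# B/H/W as generic batch and spatial sizes and n_labels == x.shape[1]):
#   xrec logits:                    (B, n_labels, H, W)
#   argmax(dim=1, keepdim=True):    (B, 1, H, W) class indices
#   F.one_hot(..., n_labels):       (B, 1, H, W, n_labels)
#   squeeze(1).permute(0, 3, 1, 2): (B, n_labels, H, W) one-hot maps
#   to_rgb (fixed random 1x1 conv): (B, 3, H, W) rescaled to [-1, 1]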


class VQNoDiscModel(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None
                 ):
        super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
                         ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
                         colorize_nlabels=colorize_nlabels)

    def training_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        # autoencode
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
        # pl.TrainResult was removed in PyTorch Lightning 1.0; log directly instead
        self.log("train/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        # pl.EvalResult(checkpoint_on=...) was removed in PyTorch Lightning 1.0;
        # log the metric and monitor "val/rec_loss" via ModelCheckpoint instead
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_ae)
        return aeloss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(list(self.encoder.parameters()) +
                                     list(self.decoder.parameters()) +
                                     list(self.quantize.parameters()) +
                                     list(self.quant_conv.parameters()) +
                                     list(self.post_quant_conv.parameters()),
                                     lr=self.learning_rate, betas=(0.5, 0.9))
        return optimizer


class GumbelVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):
        z_channels = ddconfig["z_channels"]
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,  # defer loading until the quantizer is swapped in below
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        self.quantize.temperature = self.temperature_scheduler(self.global_step)

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        raise NotImplementedError

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        # the inherited forward() takes no keyword arguments, so the stray
        # return_pred_indices flag would raise a TypeError
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")
        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return log_dict_ae

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log
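
# Example temperature_scheduler_config for GumbelVQ (a sketch; any object
# returned by instantiate_from_config that maps global_step -> temperature
# works here. taming.lr_scheduler.LambdaWarmUpCosineScheduler is one such
# class shipped with this repo; the parameter values are assumptions):
#
#   temperature_scheduler_config:
#     target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
#     params:
#       warm_up_steps: 0
#       max_decay_steps: 1000001
#       lr_start: 0.9
#       lr_max: 0.9
#       lr_min: 1.0e-06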


class EMAVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,  # defer loading until the quantizer is swapped in below
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )
        self.quantize = EMAVectorQuantizer(n_embed=n_embed,
                                           embedding_dim=embed_dim,
                                           beta=0.25,
                                           remap=remap)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def configure_optimizers(self):
        lr = self.learning_rate
        # remove self.quantize from the parameter list since it is updated via EMA
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []
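
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # entry point). The ddconfig values are assumptions loosely modeled on the
    # published VQGAN configs; the loss is bypassed by wiring the submodules
    # together directly, so no LPIPS/discriminator weights are needed.
    ddconfig = dict(double_z=False, z_channels=256, resolution=64,
                    in_channels=3, out_ch=3, ch=128, ch_mult=(1, 2, 4),
                    num_res_blocks=2, attn_resolutions=[16], dropout=0.0)
    embed_dim, n_embed = 256, 1024
    encoder, decoder = Encoder(**ddconfig), Decoder(**ddconfig)
    quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
    quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
    post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    x = torch.randn(1, 3, 64, 64)           # dummy NCHW image in [-1, 1]
    h = quant_conv(encoder(x))              # pre-quantization features
    quant, emb_loss, info = quantize(h)     # nearest-codebook-entry lookup
    xrec = decoder(post_quant_conv(quant))  # reconstruction
    print(x.shape, h.shape, quant.shape, xrec.shape)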