vqgan.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. import torch
  2. import torch.nn.functional as F
  3. import pytorch_lightning as pl
  4. from main import instantiate_from_config
  5. from taming.modules.diffusionmodules.model import Encoder, Decoder
  6. from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
  7. from taming.modules.vqvae.quantize import GumbelQuantize
  8. from taming.modules.vqvae.quantize import EMAVectorQuantizer
  9. class VQModel(pl.LightningModule):
  10. def __init__(self,
  11. ddconfig,
  12. lossconfig,
  13. n_embed,
  14. embed_dim,
  15. ckpt_path=None,
  16. ignore_keys=[],
  17. image_key="image",
  18. colorize_nlabels=None,
  19. monitor=None,
  20. remap=None,
  21. sane_index_shape=False, # tell vector quantizer to return indices as bhw
  22. ):
  23. super().__init__()
  24. self.image_key = image_key
  25. self.encoder = Encoder(**ddconfig)
  26. self.decoder = Decoder(**ddconfig)
  27. self.loss = instantiate_from_config(lossconfig)
  28. self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
  29. remap=remap, sane_index_shape=sane_index_shape)
  30. self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
  31. self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
  32. if ckpt_path is not None:
  33. self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  34. self.image_key = image_key
  35. if colorize_nlabels is not None:
  36. assert type(colorize_nlabels)==int
  37. self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
  38. if monitor is not None:
  39. self.monitor = monitor
  40. def init_from_ckpt(self, path, ignore_keys=list()):
  41. sd = torch.load(path, map_location="cpu")["state_dict"]
  42. keys = list(sd.keys())
  43. for k in keys:
  44. for ik in ignore_keys:
  45. if k.startswith(ik):
  46. print("Deleting key {} from state_dict.".format(k))
  47. del sd[k]
  48. self.load_state_dict(sd, strict=False)
  49. print(f"Restored from {path}")
  50. def encode(self, x):
  51. h = self.encoder(x)
  52. h = self.quant_conv(h)
  53. quant, emb_loss, info = self.quantize(h)
  54. return quant, emb_loss, info
  55. def decode(self, quant):
  56. quant = self.post_quant_conv(quant)
  57. dec = self.decoder(quant)
  58. return dec
  59. def decode_code(self, code_b):
  60. quant_b = self.quantize.embed_code(code_b)
  61. dec = self.decode(quant_b)
  62. return dec
  63. def forward(self, input):
  64. quant, diff, _ = self.encode(input)
  65. dec = self.decode(quant)
  66. return dec, diff
  67. def get_input(self, batch, k):
  68. x = batch[k]
  69. if len(x.shape) == 3:
  70. x = x[..., None]
  71. x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
  72. return x.float()
  73. def training_step(self, batch, batch_idx, optimizer_idx):
  74. x = self.get_input(batch, self.image_key)
  75. xrec, qloss = self(x)
  76. if optimizer_idx == 0:
  77. # autoencode
  78. aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  79. last_layer=self.get_last_layer(), split="train")
  80. self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
  81. self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  82. return aeloss
  83. if optimizer_idx == 1:
  84. # discriminator
  85. discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  86. last_layer=self.get_last_layer(), split="train")
  87. self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
  88. self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  89. return discloss
  90. def validation_step(self, batch, batch_idx):
  91. x = self.get_input(batch, self.image_key)
  92. xrec, qloss = self(x)
  93. aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
  94. last_layer=self.get_last_layer(), split="val")
  95. discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
  96. last_layer=self.get_last_layer(), split="val")
  97. rec_loss = log_dict_ae["val/rec_loss"]
  98. self.log("val/rec_loss", rec_loss,
  99. prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
  100. self.log("val/aeloss", aeloss,
  101. prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
  102. self.log_dict(log_dict_ae)
  103. self.log_dict(log_dict_disc)
  104. return self.log_dict
  105. def configure_optimizers(self):
  106. lr = self.learning_rate
  107. opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
  108. list(self.decoder.parameters())+
  109. list(self.quantize.parameters())+
  110. list(self.quant_conv.parameters())+
  111. list(self.post_quant_conv.parameters()),
  112. lr=lr, betas=(0.5, 0.9))
  113. opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
  114. lr=lr, betas=(0.5, 0.9))
  115. return [opt_ae, opt_disc], []
  116. def get_last_layer(self):
  117. return self.decoder.conv_out.weight
  118. def log_images(self, batch, **kwargs):
  119. log = dict()
  120. x = self.get_input(batch, self.image_key)
  121. x = x.to(self.device)
  122. xrec, _ = self(x)
  123. if x.shape[1] > 3:
  124. # colorize with random projection
  125. assert xrec.shape[1] > 3
  126. x = self.to_rgb(x)
  127. xrec = self.to_rgb(xrec)
  128. log["inputs"] = x
  129. log["reconstructions"] = xrec
  130. return log
  131. def to_rgb(self, x):
  132. assert self.image_key == "segmentation"
  133. if not hasattr(self, "colorize"):
  134. self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
  135. x = F.conv2d(x, weight=self.colorize)
  136. x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
  137. return x
  138. class VQSegmentationModel(VQModel):
  139. def __init__(self, n_labels, *args, **kwargs):
  140. super().__init__(*args, **kwargs)
  141. self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
  142. def configure_optimizers(self):
  143. lr = self.learning_rate
  144. opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
  145. list(self.decoder.parameters())+
  146. list(self.quantize.parameters())+
  147. list(self.quant_conv.parameters())+
  148. list(self.post_quant_conv.parameters()),
  149. lr=lr, betas=(0.5, 0.9))
  150. return opt_ae
  151. def training_step(self, batch, batch_idx):
  152. x = self.get_input(batch, self.image_key)
  153. xrec, qloss = self(x)
  154. aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
  155. self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  156. return aeloss
  157. def validation_step(self, batch, batch_idx):
  158. x = self.get_input(batch, self.image_key)
  159. xrec, qloss = self(x)
  160. aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
  161. self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  162. total_loss = log_dict_ae["val/total_loss"]
  163. self.log("val/total_loss", total_loss,
  164. prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
  165. return aeloss
  166. @torch.no_grad()
  167. def log_images(self, batch, **kwargs):
  168. log = dict()
  169. x = self.get_input(batch, self.image_key)
  170. x = x.to(self.device)
  171. xrec, _ = self(x)
  172. if x.shape[1] > 3:
  173. # colorize with random projection
  174. assert xrec.shape[1] > 3
  175. # convert logits to indices
  176. xrec = torch.argmax(xrec, dim=1, keepdim=True)
  177. xrec = F.one_hot(xrec, num_classes=x.shape[1])
  178. xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
  179. x = self.to_rgb(x)
  180. xrec = self.to_rgb(xrec)
  181. log["inputs"] = x
  182. log["reconstructions"] = xrec
  183. return log
  184. class VQNoDiscModel(VQModel):
  185. def __init__(self,
  186. ddconfig,
  187. lossconfig,
  188. n_embed,
  189. embed_dim,
  190. ckpt_path=None,
  191. ignore_keys=[],
  192. image_key="image",
  193. colorize_nlabels=None
  194. ):
  195. super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
  196. ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
  197. colorize_nlabels=colorize_nlabels)
  198. def training_step(self, batch, batch_idx):
  199. x = self.get_input(batch, self.image_key)
  200. xrec, qloss = self(x)
  201. # autoencode
  202. aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
  203. output = pl.TrainResult(minimize=aeloss)
  204. output.log("train/aeloss", aeloss,
  205. prog_bar=True, logger=True, on_step=True, on_epoch=True)
  206. output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  207. return output
  208. def validation_step(self, batch, batch_idx):
  209. x = self.get_input(batch, self.image_key)
  210. xrec, qloss = self(x)
  211. aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
  212. rec_loss = log_dict_ae["val/rec_loss"]
  213. output = pl.EvalResult(checkpoint_on=rec_loss)
  214. output.log("val/rec_loss", rec_loss,
  215. prog_bar=True, logger=True, on_step=True, on_epoch=True)
  216. output.log("val/aeloss", aeloss,
  217. prog_bar=True, logger=True, on_step=True, on_epoch=True)
  218. output.log_dict(log_dict_ae)
  219. return output
  220. def configure_optimizers(self):
  221. optimizer = torch.optim.Adam(list(self.encoder.parameters())+
  222. list(self.decoder.parameters())+
  223. list(self.quantize.parameters())+
  224. list(self.quant_conv.parameters())+
  225. list(self.post_quant_conv.parameters()),
  226. lr=self.learning_rate, betas=(0.5, 0.9))
  227. return optimizer
  228. class GumbelVQ(VQModel):
  229. def __init__(self,
  230. ddconfig,
  231. lossconfig,
  232. n_embed,
  233. embed_dim,
  234. temperature_scheduler_config,
  235. ckpt_path=None,
  236. ignore_keys=[],
  237. image_key="image",
  238. colorize_nlabels=None,
  239. monitor=None,
  240. kl_weight=1e-8,
  241. remap=None,
  242. ):
  243. z_channels = ddconfig["z_channels"]
  244. super().__init__(ddconfig,
  245. lossconfig,
  246. n_embed,
  247. embed_dim,
  248. ckpt_path=None,
  249. ignore_keys=ignore_keys,
  250. image_key=image_key,
  251. colorize_nlabels=colorize_nlabels,
  252. monitor=monitor,
  253. )
  254. self.loss.n_classes = n_embed
  255. self.vocab_size = n_embed
  256. self.quantize = GumbelQuantize(z_channels, embed_dim,
  257. n_embed=n_embed,
  258. kl_weight=kl_weight, temp_init=1.0,
  259. remap=remap)
  260. self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
  261. if ckpt_path is not None:
  262. self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  263. def temperature_scheduling(self):
  264. self.quantize.temperature = self.temperature_scheduler(self.global_step)
  265. def encode_to_prequant(self, x):
  266. h = self.encoder(x)
  267. h = self.quant_conv(h)
  268. return h
  269. def decode_code(self, code_b):
  270. raise NotImplementedError
  271. def training_step(self, batch, batch_idx, optimizer_idx):
  272. self.temperature_scheduling()
  273. x = self.get_input(batch, self.image_key)
  274. xrec, qloss = self(x)
  275. if optimizer_idx == 0:
  276. # autoencode
  277. aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  278. last_layer=self.get_last_layer(), split="train")
  279. self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  280. self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  281. return aeloss
  282. if optimizer_idx == 1:
  283. # discriminator
  284. discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  285. last_layer=self.get_last_layer(), split="train")
  286. self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  287. return discloss
  288. def validation_step(self, batch, batch_idx):
  289. x = self.get_input(batch, self.image_key)
  290. xrec, qloss = self(x, return_pred_indices=True)
  291. aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
  292. last_layer=self.get_last_layer(), split="val")
  293. discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
  294. last_layer=self.get_last_layer(), split="val")
  295. rec_loss = log_dict_ae["val/rec_loss"]
  296. self.log("val/rec_loss", rec_loss,
  297. prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
  298. self.log("val/aeloss", aeloss,
  299. prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
  300. self.log_dict(log_dict_ae)
  301. self.log_dict(log_dict_disc)
  302. return self.log_dict
  303. def log_images(self, batch, **kwargs):
  304. log = dict()
  305. x = self.get_input(batch, self.image_key)
  306. x = x.to(self.device)
  307. # encode
  308. h = self.encoder(x)
  309. h = self.quant_conv(h)
  310. quant, _, _ = self.quantize(h)
  311. # decode
  312. x_rec = self.decode(quant)
  313. log["inputs"] = x
  314. log["reconstructions"] = x_rec
  315. return log
  316. class EMAVQ(VQModel):
  317. def __init__(self,
  318. ddconfig,
  319. lossconfig,
  320. n_embed,
  321. embed_dim,
  322. ckpt_path=None,
  323. ignore_keys=[],
  324. image_key="image",
  325. colorize_nlabels=None,
  326. monitor=None,
  327. remap=None,
  328. sane_index_shape=False, # tell vector quantizer to return indices as bhw
  329. ):
  330. super().__init__(ddconfig,
  331. lossconfig,
  332. n_embed,
  333. embed_dim,
  334. ckpt_path=None,
  335. ignore_keys=ignore_keys,
  336. image_key=image_key,
  337. colorize_nlabels=colorize_nlabels,
  338. monitor=monitor,
  339. )
  340. self.quantize = EMAVectorQuantizer(n_embed=n_embed,
  341. embedding_dim=embed_dim,
  342. beta=0.25,
  343. remap=remap)
  344. def configure_optimizers(self):
  345. lr = self.learning_rate
  346. #Remove self.quantize from parameter list since it is updated via EMA
  347. opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
  348. list(self.decoder.parameters())+
  349. list(self.quant_conv.parameters())+
  350. list(self.post_quant_conv.parameters()),
  351. lr=lr, betas=(0.5, 0.9))
  352. opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
  353. lr=lr, betas=(0.5, 0.9))
  354. return [opt_ae, opt_disc], []