2
0

quantize.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import numpy as np
  5. from torch import einsum
  6. from einops import rearrange
  7. class VectorQuantizer(nn.Module):
  8. """
  9. see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
  10. ____________________________________________
  11. Discretization bottleneck part of the VQ-VAE.
  12. Inputs:
  13. - n_e : number of embeddings
  14. - e_dim : dimension of embedding
  15. - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
  16. _____________________________________________
  17. """
  18. # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
  19. # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
  20. # used wherever VectorQuantizer has been used before and is additionally
  21. # more efficient.
  22. def __init__(self, n_e, e_dim, beta):
  23. super(VectorQuantizer, self).__init__()
  24. self.n_e = n_e
  25. self.e_dim = e_dim
  26. self.beta = beta
  27. self.embedding = nn.Embedding(self.n_e, self.e_dim)
  28. self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
  29. def forward(self, z):
  30. """
  31. Inputs the output of the encoder network z and maps it to a discrete
  32. one-hot vector that is the index of the closest embedding vector e_j
  33. z (continuous) -> z_q (discrete)
  34. z.shape = (batch, channel, height, width)
  35. quantization pipeline:
  36. 1. get encoder input (B,C,H,W)
  37. 2. flatten input to (B*H*W,C)
  38. """
  39. # reshape z -> (batch, height, width, channel) and flatten
  40. z = z.permute(0, 2, 3, 1).contiguous()
  41. z_flattened = z.view(-1, self.e_dim)
  42. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  43. d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
  44. torch.sum(self.embedding.weight**2, dim=1) - 2 * \
  45. torch.matmul(z_flattened, self.embedding.weight.t())
  46. ## could possible replace this here
  47. # #\start...
  48. # find closest encodings
  49. min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
  50. min_encodings = torch.zeros(
  51. min_encoding_indices.shape[0], self.n_e).to(z)
  52. min_encodings.scatter_(1, min_encoding_indices, 1)
  53. # dtype min encodings: torch.float32
  54. # min_encodings shape: torch.Size([2048, 512])
  55. # min_encoding_indices.shape: torch.Size([2048, 1])
  56. # get quantized latent vectors
  57. z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
  58. #.........\end
  59. # with:
  60. # .........\start
  61. #min_encoding_indices = torch.argmin(d, dim=1)
  62. #z_q = self.embedding(min_encoding_indices)
  63. # ......\end......... (TODO)
  64. # compute loss for embedding
  65. loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
  66. torch.mean((z_q - z.detach()) ** 2)
  67. # preserve gradients
  68. z_q = z + (z_q - z).detach()
  69. # perplexity
  70. e_mean = torch.mean(min_encodings, dim=0)
  71. perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
  72. # reshape back to match original input shape
  73. z_q = z_q.permute(0, 3, 1, 2).contiguous()
  74. return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
  75. def get_codebook_entry(self, indices, shape):
  76. # shape specifying (batch, height, width, channel)
  77. # TODO: check for more easy handling with nn.Embedding
  78. min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
  79. min_encodings.scatter_(1, indices[:,None], 1)
  80. # get quantized latent vectors
  81. z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
  82. if shape is not None:
  83. z_q = z_q.view(shape)
  84. # reshape back to match original input shape
  85. z_q = z_q.permute(0, 3, 1, 2).contiguous()
  86. return z_q
  87. class GumbelQuantize(nn.Module):
  88. """
  89. credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
  90. Gumbel Softmax trick quantizer
  91. Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
  92. https://arxiv.org/abs/1611.01144
  93. """
  94. def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
  95. kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
  96. remap=None, unknown_index="random"):
  97. super().__init__()
  98. self.embedding_dim = embedding_dim
  99. self.n_embed = n_embed
  100. self.straight_through = straight_through
  101. self.temperature = temp_init
  102. self.kl_weight = kl_weight
  103. self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
  104. self.embed = nn.Embedding(n_embed, embedding_dim)
  105. self.use_vqinterface = use_vqinterface
  106. self.remap = remap
  107. if self.remap is not None:
  108. self.register_buffer("used", torch.tensor(np.load(self.remap)))
  109. self.re_embed = self.used.shape[0]
  110. self.unknown_index = unknown_index # "random" or "extra" or integer
  111. if self.unknown_index == "extra":
  112. self.unknown_index = self.re_embed
  113. self.re_embed = self.re_embed+1
  114. print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
  115. f"Using {self.unknown_index} for unknown indices.")
  116. else:
  117. self.re_embed = n_embed
  118. def remap_to_used(self, inds):
  119. ishape = inds.shape
  120. assert len(ishape)>1
  121. inds = inds.reshape(ishape[0],-1)
  122. used = self.used.to(inds)
  123. match = (inds[:,:,None]==used[None,None,...]).long()
  124. new = match.argmax(-1)
  125. unknown = match.sum(2)<1
  126. if self.unknown_index == "random":
  127. new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
  128. else:
  129. new[unknown] = self.unknown_index
  130. return new.reshape(ishape)
  131. def unmap_to_all(self, inds):
  132. ishape = inds.shape
  133. assert len(ishape)>1
  134. inds = inds.reshape(ishape[0],-1)
  135. used = self.used.to(inds)
  136. if self.re_embed > self.used.shape[0]: # extra token
  137. inds[inds>=self.used.shape[0]] = 0 # simply set to zero
  138. back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
  139. return back.reshape(ishape)
  140. def forward(self, z, temp=None, return_logits=False):
  141. # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
  142. hard = self.straight_through if self.training else True
  143. temp = self.temperature if temp is None else temp
  144. logits = self.proj(z)
  145. if self.remap is not None:
  146. # continue only with used logits
  147. full_zeros = torch.zeros_like(logits)
  148. logits = logits[:,self.used,...]
  149. soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
  150. if self.remap is not None:
  151. # go back to all entries but unused set to zero
  152. full_zeros[:,self.used,...] = soft_one_hot
  153. soft_one_hot = full_zeros
  154. z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
  155. # + kl divergence to the prior loss
  156. qy = F.softmax(logits, dim=1)
  157. diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
  158. ind = soft_one_hot.argmax(dim=1)
  159. if self.remap is not None:
  160. ind = self.remap_to_used(ind)
  161. if self.use_vqinterface:
  162. if return_logits:
  163. return z_q, diff, (None, None, ind), logits
  164. return z_q, diff, (None, None, ind)
  165. return z_q, diff, ind
  166. def get_codebook_entry(self, indices, shape):
  167. b, h, w, c = shape
  168. assert b*h*w == indices.shape[0]
  169. indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
  170. if self.remap is not None:
  171. indices = self.unmap_to_all(indices)
  172. one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
  173. z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
  174. return z_q
  175. class VectorQuantizer2(nn.Module):
  176. """
  177. Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
  178. avoids costly matrix multiplications and allows for post-hoc remapping of indices.
  179. """
  180. # NOTE: due to a bug the beta term was applied to the wrong term. for
  181. # backwards compatibility we use the buggy version by default, but you can
  182. # specify legacy=False to fix it.
  183. def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
  184. sane_index_shape=False, legacy=True):
  185. super().__init__()
  186. self.n_e = n_e
  187. self.e_dim = e_dim
  188. self.beta = beta
  189. self.legacy = legacy
  190. self.embedding = nn.Embedding(self.n_e, self.e_dim)
  191. self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
  192. self.remap = remap
  193. if self.remap is not None:
  194. self.register_buffer("used", torch.tensor(np.load(self.remap)))
  195. self.re_embed = self.used.shape[0]
  196. self.unknown_index = unknown_index # "random" or "extra" or integer
  197. if self.unknown_index == "extra":
  198. self.unknown_index = self.re_embed
  199. self.re_embed = self.re_embed+1
  200. print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
  201. f"Using {self.unknown_index} for unknown indices.")
  202. else:
  203. self.re_embed = n_e
  204. self.sane_index_shape = sane_index_shape
  205. def remap_to_used(self, inds):
  206. ishape = inds.shape
  207. assert len(ishape)>1
  208. inds = inds.reshape(ishape[0],-1)
  209. used = self.used.to(inds)
  210. match = (inds[:,:,None]==used[None,None,...]).long()
  211. new = match.argmax(-1)
  212. unknown = match.sum(2)<1
  213. if self.unknown_index == "random":
  214. new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
  215. else:
  216. new[unknown] = self.unknown_index
  217. return new.reshape(ishape)
  218. def unmap_to_all(self, inds):
  219. ishape = inds.shape
  220. assert len(ishape)>1
  221. inds = inds.reshape(ishape[0],-1)
  222. used = self.used.to(inds)
  223. if self.re_embed > self.used.shape[0]: # extra token
  224. inds[inds>=self.used.shape[0]] = 0 # simply set to zero
  225. back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
  226. return back.reshape(ishape)
  227. def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
  228. assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
  229. assert rescale_logits==False, "Only for interface compatible with Gumbel"
  230. assert return_logits==False, "Only for interface compatible with Gumbel"
  231. # reshape z -> (batch, height, width, channel) and flatten
  232. z = rearrange(z, 'b c h w -> b h w c').contiguous()
  233. z_flattened = z.view(-1, self.e_dim)
  234. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  235. d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
  236. torch.sum(self.embedding.weight**2, dim=1) - 2 * \
  237. torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
  238. min_encoding_indices = torch.argmin(d, dim=1)
  239. z_q = self.embedding(min_encoding_indices).view(z.shape)
  240. perplexity = None
  241. min_encodings = None
  242. # compute loss for embedding
  243. if not self.legacy:
  244. loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
  245. torch.mean((z_q - z.detach()) ** 2)
  246. else:
  247. loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
  248. torch.mean((z_q - z.detach()) ** 2)
  249. # preserve gradients
  250. z_q = z + (z_q - z).detach()
  251. # reshape back to match original input shape
  252. z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
  253. if self.remap is not None:
  254. min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
  255. min_encoding_indices = self.remap_to_used(min_encoding_indices)
  256. min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
  257. if self.sane_index_shape:
  258. min_encoding_indices = min_encoding_indices.reshape(
  259. z_q.shape[0], z_q.shape[2], z_q.shape[3])
  260. return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
  261. def get_codebook_entry(self, indices, shape):
  262. # shape specifying (batch, height, width, channel)
  263. if self.remap is not None:
  264. indices = indices.reshape(shape[0],-1) # add batch axis
  265. indices = self.unmap_to_all(indices)
  266. indices = indices.reshape(-1) # flatten again
  267. # get quantized latent vectors
  268. z_q = self.embedding(indices)
  269. if shape is not None:
  270. z_q = z_q.view(shape)
  271. # reshape back to match original input shape
  272. z_q = z_q.permute(0, 3, 1, 2).contiguous()
  273. return z_q
  274. class EmbeddingEMA(nn.Module):
  275. def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
  276. super().__init__()
  277. self.decay = decay
  278. self.eps = eps
  279. weight = torch.randn(num_tokens, codebook_dim)
  280. self.weight = nn.Parameter(weight, requires_grad = False)
  281. self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
  282. self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
  283. self.update = True
  284. def forward(self, embed_id):
  285. return F.embedding(embed_id, self.weight)
  286. def cluster_size_ema_update(self, new_cluster_size):
  287. self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
  288. def embed_avg_ema_update(self, new_embed_avg):
  289. self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
  290. def weight_update(self, num_tokens):
  291. n = self.cluster_size.sum()
  292. smoothed_cluster_size = (
  293. (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
  294. )
  295. #normalize embedding average with smoothed cluster size
  296. embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
  297. self.weight.data.copy_(embed_normalized)
  298. class EMAVectorQuantizer(nn.Module):
  299. def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
  300. remap=None, unknown_index="random"):
  301. super().__init__()
  302. self.codebook_dim = codebook_dim
  303. self.num_tokens = num_tokens
  304. self.beta = beta
  305. self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
  306. self.remap = remap
  307. if self.remap is not None:
  308. self.register_buffer("used", torch.tensor(np.load(self.remap)))
  309. self.re_embed = self.used.shape[0]
  310. self.unknown_index = unknown_index # "random" or "extra" or integer
  311. if self.unknown_index == "extra":
  312. self.unknown_index = self.re_embed
  313. self.re_embed = self.re_embed+1
  314. print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
  315. f"Using {self.unknown_index} for unknown indices.")
  316. else:
  317. self.re_embed = n_embed
  318. def remap_to_used(self, inds):
  319. ishape = inds.shape
  320. assert len(ishape)>1
  321. inds = inds.reshape(ishape[0],-1)
  322. used = self.used.to(inds)
  323. match = (inds[:,:,None]==used[None,None,...]).long()
  324. new = match.argmax(-1)
  325. unknown = match.sum(2)<1
  326. if self.unknown_index == "random":
  327. new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
  328. else:
  329. new[unknown] = self.unknown_index
  330. return new.reshape(ishape)
  331. def unmap_to_all(self, inds):
  332. ishape = inds.shape
  333. assert len(ishape)>1
  334. inds = inds.reshape(ishape[0],-1)
  335. used = self.used.to(inds)
  336. if self.re_embed > self.used.shape[0]: # extra token
  337. inds[inds>=self.used.shape[0]] = 0 # simply set to zero
  338. back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
  339. return back.reshape(ishape)
  340. def forward(self, z):
  341. # reshape z -> (batch, height, width, channel) and flatten
  342. #z, 'b c h w -> b h w c'
  343. z = rearrange(z, 'b c h w -> b h w c')
  344. z_flattened = z.reshape(-1, self.codebook_dim)
  345. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  346. d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
  347. self.embedding.weight.pow(2).sum(dim=1) - 2 * \
  348. torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
  349. encoding_indices = torch.argmin(d, dim=1)
  350. z_q = self.embedding(encoding_indices).view(z.shape)
  351. encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
  352. avg_probs = torch.mean(encodings, dim=0)
  353. perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
  354. if self.training and self.embedding.update:
  355. #EMA cluster size
  356. encodings_sum = encodings.sum(0)
  357. self.embedding.cluster_size_ema_update(encodings_sum)
  358. #EMA embedding average
  359. embed_sum = encodings.transpose(0,1) @ z_flattened
  360. self.embedding.embed_avg_ema_update(embed_sum)
  361. #normalize embed_avg and update weight
  362. self.embedding.weight_update(self.num_tokens)
  363. # compute loss for embedding
  364. loss = self.beta * F.mse_loss(z_q.detach(), z)
  365. # preserve gradients
  366. z_q = z + (z_q - z).detach()
  367. # reshape back to match original input shape
  368. #z_q, 'b h w c -> b c h w'
  369. z_q = rearrange(z_q, 'b h w c -> b c h w')
  370. return z_q, loss, (perplexity, encodings, encoding_indices)