# core_vq.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""
import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    """In-place exponential moving average update of `moving_avg` with `new`."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth the counts `x` so that no category gets zero probability."""
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    """Sample `num` vectors from `samples`, with replacement if there are too few."""
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]

def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run a few iterations of k-means on `samples` to obtain `num_clusters` centroids."""
    dim, dtype = samples.shape[-1], samples.dtype
    # Cap the number of samples to keep the initialization cheap.
    max_kmeans_samples = 500
    samples = samples[:max_kmeans_samples, :]
    means = sample_vectors(samples, num_clusters)
    print("kmeans start ... ")
    for _ in tqdm(range(num_iters)):
        # Assign each sample to its nearest centroid (largest negative squared distance).
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs ** 2).sum(dim=-1)
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
        # Recompute each centroid as the mean of its assigned samples.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        # Keep the previous centroid for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins

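
# Illustrative usage sketch for `kmeans` (the shapes below are assumptions for the
# example, not values defined in this module). Note that only the first
# `max_kmeans_samples` vectors are used:
#
#     feats = torch.randn(2000, 128)                       # (num_samples, dim)
#     centroids, counts = kmeans(feats, num_clusters=64)   # (64, 128), (64,)
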
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebook.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for the k-means algorithm at initialization.
        decay (float): Decay for the exponential moving average over the codebook.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any code
            whose exponential moving average cluster size is less than the specified threshold
            with a randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())
    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # broadcast_tensors(self.buffers())
    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize
    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            # EMA update of the per-code usage counts and of the running sum of
            # assigned vectors, followed by a Laplace-smoothed codebook refresh.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind

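
# Illustrative usage sketch for EuclideanCodebook (the shapes and sizes are
# assumptions chosen for the example). Inputs are shaped (..., dim); the module
# returns the quantized vectors together with the selected code indices:
#
#     codebook = EuclideanCodebook(dim=128, codebook_size=1024, kmeans_init=False)
#     x = torch.randn(8, 50, 128)           # (batch, frames, dim)
#     quantized, indices = codebook(x)      # quantized: (8, 50, 128), indices: (8, 50)
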
class VectorQuantization(nn.Module):
    """Vector quantization implementation.

    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for the exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for k-means initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any code
            whose exponential moving average cluster size is less than the specified threshold
            with a randomly selected vector from the current batch.
        commitment_weight (float): Weight for the commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed
    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize
    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through estimator: gradients flow to x as if no quantization happened.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss

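
# Illustrative usage sketch for VectorQuantization (the shapes and sizes are
# assumptions for the example). Inputs follow the (batch, dim, frames) convention
# used by encode/decode/forward and are transposed internally:
#
#     vq = VectorQuantization(dim=256, codebook_size=1024, kmeans_init=False)
#     x = torch.randn(4, 256, 100)          # (batch, dim, frames)
#     quantized, indices, loss = vq(x)      # (4, 256, 100), (4, 100), (1,)
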
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        out_quantized = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            # Each quantizer encodes the residual left by the previous ones.
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            if layers and i in layers:
                out_quantized.append(quantized)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized
    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        st = st or 0
        for layer in self.layers[st:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[st + i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
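
# Minimal, self-contained smoke test for ResidualVectorQuantization. This is an
# illustrative sketch: the sizes and hyper-parameters below are assumptions for
# the example, not values defined elsewhere in this module.
if __name__ == "__main__":
    torch.manual_seed(0)
    rvq = ResidualVectorQuantization(
        num_quantizers=4, dim=128, codebook_size=256, kmeans_init=False
    )
    x = torch.randn(2, 128, 75)  # (batch, dim, frames)

    # Training-style pass: returns the summed quantization, per-layer indices and losses.
    rvq.train()
    quantized, indices, losses, _ = rvq(x)
    print(quantized.shape, indices.shape, losses.shape)  # (2, 128, 75), (4, 2, 75), (4, 1)

    # Inference-style pass: encode to discrete codes, then decode back.
    rvq.eval()
    codes = rvq.encode(x, n_q=2)        # use only the first two quantizers
    decoded = rvq.decode(codes)         # approximate reconstruction from the codes
    print(codes.shape, decoded.shape)   # (2, 2, 75), (2, 128, 75)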