  1. """
  2. taken from: https://github.com/karpathy/minGPT/
  3. GPT model:
  4. - the initial stem consists of a combination of token encoding and a positional encoding
  5. - the meat of it is a uniform sequence of Transformer blocks
  6. - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
  7. - all blocks feed into a central residual pathway similar to resnets
  8. - the final decoder is a linear projection into a vanilla Softmax classifier
  9. """
  10. import math
  11. import logging
  12. import torch
  13. import torch.nn as nn
  14. from torch.nn import functional as F
  15. from transformers import top_k_top_p_filtering
  16. logger = logging.getLogger(__name__)


class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768
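
# Illustrative usage (added comment, not part of the upstream file; values are arbitrary):
# GPTConfig stores the two required sizes plus any extra keyword arguments as attributes,
# so model-specific settings can be attached without subclassing, e.g.
#
#   cfg = GPTConfig(vocab_size=1024, block_size=256,
#                   n_layer=12, n_head=8, n_embd=256, n_unmasked=0)
#   cfg.n_embd       # -> 256 (set via **kwargs)
#   cfg.attn_pdrop   # -> 0.1 (class-level default)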


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        mask = torch.tril(torch.ones(config.block_size,
                                     config.block_size))
        if hasattr(config, "n_unmasked"):
            mask[:config.n_unmasked, :config.n_unmasked] = 1
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        present = torch.stack((k, v))
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if layer_past is None:
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, present   # TODO: check that this does not break anything
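
# Note (added comment, not part of the upstream file): when layer_past is supplied, the new
# keys/values are appended to the cached ones and the causal mask is skipped, which is only
# valid when x carries a single new time step (T == 1) during autoregressive decoding.
# `present` holds this layer's fresh (k, v) pair stacked as (2, B, nh, T, hs), so the caller
# can extend its cache for the next step.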


class Block(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),  # nice
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):
        # TODO: check that training still works
        if return_present: assert not self.training
        # layer past: tuple of length two with B, nh, T, hs
        attn, present = self.attn(self.ln1(x), layer_past=layer_past)

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        if layer_past is not None or return_present:
            return x, present
        return x
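
# Note (added comment, not part of the upstream file): the block uses the pre-LayerNorm
# residual layout, i.e. x <- x + attn(ln1(x)) followed by x <- x + mlp(ln2(x)), so every
# sublayer writes into the central residual pathway mentioned in the module docstring.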


class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """
    def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
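
    # Illustrative shapes (added comment, not part of the upstream file; sizes are arbitrary):
    #
    #   model = GPT(vocab_size=1024, block_size=256, n_layer=4, n_head=4, n_embd=128)
    #   idx = torch.randint(0, 1024, (2, 16))      # (B=2, T=16) token indices
    #   logits, loss = model(idx, targets=idx)     # logits: (2, 16, 1024), loss: scalar
    #
    # If `embeddings` of shape (B, T_c, n_embd) is prepended, the logits cover T_c + T positions.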

    def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
        # inference only
        assert not self.training
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        if embeddings is not None:            # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        if past is not None:
            assert past_length is not None
            past = torch.cat(past, dim=-2)  # n_layer, 2, b, nh, len_past, dim_head
            past_shape = list(past.shape)
            expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
            assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
            position_embeddings = self.pos_emb[:, past_length, :]  # each position maps to a (learnable) vector
        else:
            position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]

        x = self.drop(token_embeddings + position_embeddings)
        presents = []  # accumulate over layers
        for i, block in enumerate(self.blocks):
            x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
            presents.append(present)

        x = self.ln_f(x)
        logits = self.head(x)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, torch.stack(presents)  # _, _, n_layer, 2, b, nh, 1, dim_head
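
# Note (added comment, not part of the upstream file): `past` is expected to be a list of the
# per-step caches returned above, each of shape (n_layer, 2, B, nh, T_step, hs), where T_step
# is the number of tokens fed at that step (the full conditioning on the first call, 1 after
# that). They are concatenated along the time axis here, and the returned stack is what the
# caller appends for the next step; see `sample_with_past` below for the intended driver loop.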


class DummyGPT(nn.Module):
    # for debugging
    def __init__(self, add_value=1):
        super().__init__()
        self.add_value = add_value

    def forward(self, idx):
        return idx + self.add_value, None


class CodeGPT(nn.Module):
    """Takes in semi-embeddings"""
    def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Linear(in_channels, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # project input features to the embedding dimension

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
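
# Note (added comment, not part of the upstream file; sizes are arbitrary): unlike GPT, CodeGPT
# embeds continuous inputs with an nn.Linear stem, so `idx` here is a float tensor of shape
# (B, T, in_channels) rather than integer token indices, e.g.
#
#   model = CodeGPT(vocab_size=1024, block_size=256, in_channels=64,
#                   n_layer=4, n_head=4, n_embd=128)
#   feats = torch.rand(2, 16, 64)          # (B, T, in_channels)
#   logits, _ = model(feats)               # (2, 16, 1024)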


#### sampling utils

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out
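
# Worked example (added comment, not part of the upstream file):
#
#   logits = torch.tensor([[1.0, 3.0, 2.0]])
#   top_k_logits(logits, k=2)   # -> tensor([[-inf, 3.0, 2.0]])
#
# Entries below the k-th largest value are set to -inf, so softmax gives them zero probability.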

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
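
# Illustrative call (added comment, not part of the upstream file; sizes are arbitrary):
#
#   model = GPT(vocab_size=1024, block_size=256, n_layer=4, n_head=4, n_embd=128)
#   prompt = torch.randint(0, 1024, (1, 8))           # (B=1, T=8) conditioning tokens
#   out = sample(model, prompt, steps=32,
#                temperature=1.0, sample=True, top_k=100)
#   # out: (1, 8 + 32); the conditioning prompt is kept at the front of the result.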


@torch.no_grad()
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
                     top_k=None, top_p=None, callback=None):
    # x is conditioning
    sample = x
    cond_len = x.shape[1]
    past = None
    for n in range(steps):
        if callback is not None:
            callback(n)
        logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

        probs = F.softmax(logits, dim=-1)
        if not sample_logits:
            _, x = torch.topk(probs, k=1, dim=-1)
        else:
            x = torch.multinomial(probs, num_samples=1)
        # append to the sequence and continue
        sample = torch.cat((sample, x), dim=1)

    del past
    sample = sample[:, cond_len:]  # cut conditioning off
    return sample
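
# Note (added comment, not part of the upstream file): unlike `sample`, this loop feeds only the
# newest token through the model and reuses the cached keys/values in `past`, so each step
# attends over the cache instead of re-running the whole prefix; the returned tensor contains
# only the newly generated tokens, with the conditioning prefix stripped off. When `top_k` is
# set, `top_p` is passed straight to transformers' top_k_top_p_filtering, which may expect a
# numeric value (e.g. 1.0) rather than None.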


#### clustering utils

class KMeans(nn.Module):
    def __init__(self, ncluster=512, nc=3, niter=10):
        super().__init__()
        self.ncluster = ncluster
        self.nc = nc
        self.niter = niter
        self.shape = (3, 32, 32)
        self.register_buffer("C", torch.zeros(self.ncluster, nc))
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def is_initialized(self):
        return self.initialized.item() == 1

    @torch.no_grad()
    def initialize(self, x):
        N, D = x.shape
        assert D == self.nc, D
        c = x[torch.randperm(N)[:self.ncluster]]  # init clusters at random
        for i in range(self.niter):
            # assign all pixels to the closest codebook element
            a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
            # move each codebook element to be the mean of the pixels that are assigned to it
            c = torch.stack([x[a == k].mean(0) for k in range(self.ncluster)])
            # re-assign any poorly positioned codebook elements
            nanix = torch.any(torch.isnan(c), dim=1)
            ndead = nanix.sum().item()
            print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
            c[nanix] = x[torch.randperm(N)[:ndead]]  # re-init dead clusters
        self.C.copy_(c)
        self.initialized.fill_(1)

    def forward(self, x, reverse=False, shape=None):
        if not reverse:
            # flatten
            bs, c, h, w = x.shape
            assert c == self.nc
            x = x.reshape(bs, c, h*w, 1)
            C = self.C.permute(1, 0)
            C = C.reshape(1, c, 1, self.ncluster)
            a = ((x-C)**2).sum(1).argmin(-1)  # bs, h*w indices
            return a
        else:
            # flatten
            bs, HW = x.shape
            """
            c = self.C.reshape(1, self.nc, 1, self.ncluster)
            c = c[bs*[0],:,:,:]
            c = c[:,:,HW*[0],:]
            x = x.reshape(bs, 1, HW, 1)
            x = x[:,3*[0],:,:]
            x = torch.gather(c, dim=3, index=x)
            """
            x = self.C[x]
            x = x.permute(0, 2, 1)
            shape = shape if shape is not None else self.shape
            x = x.reshape(bs, *shape)

            return x
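

if __name__ == "__main__":
    # Added illustrative smoke test, not part of the upstream file; the hyperparameters below
    # are arbitrary small values chosen so this runs quickly on CPU.
    torch.manual_seed(0)

    # tiny GPT: forward pass with a loss, then autoregressive sampling
    gpt = GPT(vocab_size=64, block_size=32, n_layer=2, n_head=2, n_embd=32)
    idx = torch.randint(0, 64, (2, 8))                            # (B, T) token indices
    logits, loss = gpt(idx, targets=idx)
    print("logits:", tuple(logits.shape), "loss:", float(loss))   # (2, 8, 64), scalar

    out = sample(gpt, idx, steps=4, temperature=1.0, sample=True, top_k=10)
    print("sampled:", tuple(out.shape))                           # (2, 12): prompt + 4 new tokens

    # tiny KMeans codebook: encode pixels to indices and decode back
    km = KMeans(ncluster=8, nc=3, niter=2)
    km.initialize(torch.rand(1000, 3))                            # fit codebook on N x nc samples
    img = torch.rand(2, 3, 32, 32)                                # (bs, nc, h, w)
    codes = km(img)                                               # (2, 1024) codebook indices
    recon = km(codes, reverse=True, shape=(3, 32, 32))            # (2, 3, 32, 32) quantized pixels
    print("codes:", tuple(codes.shape), "recon:", tuple(recon.shape))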