# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (which in turn follows Fairseq/tensor2tensor), but differs slightly from
    the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
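

# For timesteps of shape (B,), the result has shape (B, embedding_dim): sin
# components in the first half_dim columns and cos components in the rest,
# e.g. get_timestep_embedding(torch.tensor([0, 1, 2]), 8) has shape (3, 8).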
def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)
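

# Note: x * sigmoid(x) is the swish/SiLU activation, equivalent to
# torch.nn.functional.silu(x).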
def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
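

# The (0, 1, 0, 1) pad adds one row/column on the bottom and right only, so
# the stride-2 3x3 conv halves even spatial dims exactly (H -> H // 2),
# matching the 2x2 average pooling used when with_conv is False.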
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
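

# ResnetBlock follows the pre-activation pattern (GroupNorm -> swish -> conv,
# twice) and injects the timestep embedding as a per-channel bias after conv1;
# when channel counts differ, the skip path is projected with a 3x3 conv
# (conv_shortcut) or a 1x1 conv (nin_shortcut).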
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
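

# This is single-head scaled dot-product self-attention over the H*W spatial
# positions, with 1x1 convs as the q/k/v and output projections and a residual
# connection around the whole block.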
class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None):
        #assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
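

# Minimal usage sketch for the full U-Net (the hyperparameters below are
# illustrative only, not taken from any config shipped with this file):
#   model = Model(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=2,
#                 attn_resolutions=(16,), in_channels=3, resolution=32)
#   x = torch.randn(4, 3, 32, 32)
#   t = torch.randint(0, 1000, (4,))
#   out = model(x, t)   # -> (4, 3, 32, 32), same spatial shape as x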
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
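

# For an input of shape (B, in_channels, resolution, resolution) the Encoder
# returns (B, 2*z_channels if double_z else z_channels, r, r) with
# r = resolution // 2**(num_resolutions - 1).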
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
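

# The Decoder mirrors the Encoder: z at resolution // 2**(num_resolutions - 1)
# is mapped to block_in channels, passed through the middle blocks, and
# upsampled back to the full resolution; give_pre_end skips the final
# norm/swish/conv head and returns the raw feature map instead.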
class VUNet(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, c_channels,
                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(c_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.z_in = torch.nn.Conv2d(z_channels,
                                    block_in,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, z, t=None):
        #assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        z = self.z_in(z)
        h = torch.cat((h, z), dim=1)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
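

# VUNet conditions on an image-like input x (c_channels) through the usual
# U-Net path and injects the latent z at the bottleneck: z is projected by a
# 1x1 conv (z_in) and concatenated with the deepest feature map before the
# middle blocks. The original forward signature omitted t even though it is
# used when use_timestep=True, so it is accepted as an optional argument here.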
class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2*in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x
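

# Indices 1-3 of self.model are ResnetBlocks and therefore take (x, temb);
# the surrounding 1x1 convs and the Upsample take a single tensor.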
class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
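

if __name__ == "__main__":
    # Minimal encoder/decoder round-trip smoke test. The hyperparameters below
    # are illustrative only (not taken from any config shipped with this file).
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                  attn_resolutions=(), in_channels=3, resolution=32,
                  z_channels=4, double_z=False)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                  attn_resolutions=(), in_channels=3, resolution=32,
                  z_channels=4)
    x = torch.randn(2, 3, 32, 32)
    z = enc(x)       # (2, 4, 16, 16): spatial dims halved once for two levels
    x_rec = dec(z)   # (2, 3, 32, 32): upsampled back to the input shape
    assert z.shape == (2, 4, 16, 16) and x_rec.shape == x.shape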