Transformer.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Transformer(nn.Module):
  5. def __init__(self):
  6. super(Transformer, self).__init__()
  7. #
  8. self.refpad01_1 = nn.ReflectionPad2d(3)
  9. self.conv01_1 = nn.Conv2d(3, 64, 7)
  10. self.in01_1 = InstanceNormalization(64)
  11. # relu
  12. self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
  13. self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
  14. self.in02_1 = InstanceNormalization(128)
  15. # relu
  16. self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
  17. self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
  18. self.in03_1 = InstanceNormalization(256)
  19. # relu
  20. ## res block 1
  21. self.refpad04_1 = nn.ReflectionPad2d(1)
  22. self.conv04_1 = nn.Conv2d(256, 256, 3)
  23. self.in04_1 = InstanceNormalization(256)
  24. # relu
  25. self.refpad04_2 = nn.ReflectionPad2d(1)
  26. self.conv04_2 = nn.Conv2d(256, 256, 3)
  27. self.in04_2 = InstanceNormalization(256)
  28. # + input
  29. ## res block 2
  30. self.refpad05_1 = nn.ReflectionPad2d(1)
  31. self.conv05_1 = nn.Conv2d(256, 256, 3)
  32. self.in05_1 = InstanceNormalization(256)
  33. # relu
  34. self.refpad05_2 = nn.ReflectionPad2d(1)
  35. self.conv05_2 = nn.Conv2d(256, 256, 3)
  36. self.in05_2 = InstanceNormalization(256)
  37. # + input
  38. ## res block 3
  39. self.refpad06_1 = nn.ReflectionPad2d(1)
  40. self.conv06_1 = nn.Conv2d(256, 256, 3)
  41. self.in06_1 = InstanceNormalization(256)
  42. # relu
  43. self.refpad06_2 = nn.ReflectionPad2d(1)
  44. self.conv06_2 = nn.Conv2d(256, 256, 3)
  45. self.in06_2 = InstanceNormalization(256)
  46. # + input
  47. ## res block 4
  48. self.refpad07_1 = nn.ReflectionPad2d(1)
  49. self.conv07_1 = nn.Conv2d(256, 256, 3)
  50. self.in07_1 = InstanceNormalization(256)
  51. # relu
  52. self.refpad07_2 = nn.ReflectionPad2d(1)
  53. self.conv07_2 = nn.Conv2d(256, 256, 3)
  54. self.in07_2 = InstanceNormalization(256)
  55. # + input
  56. ## res block 5
  57. self.refpad08_1 = nn.ReflectionPad2d(1)
  58. self.conv08_1 = nn.Conv2d(256, 256, 3)
  59. self.in08_1 = InstanceNormalization(256)
  60. # relu
  61. self.refpad08_2 = nn.ReflectionPad2d(1)
  62. self.conv08_2 = nn.Conv2d(256, 256, 3)
  63. self.in08_2 = InstanceNormalization(256)
  64. # + input
  65. ## res block 6
  66. self.refpad09_1 = nn.ReflectionPad2d(1)
  67. self.conv09_1 = nn.Conv2d(256, 256, 3)
  68. self.in09_1 = InstanceNormalization(256)
  69. # relu
  70. self.refpad09_2 = nn.ReflectionPad2d(1)
  71. self.conv09_2 = nn.Conv2d(256, 256, 3)
  72. self.in09_2 = InstanceNormalization(256)
  73. # + input
  74. ## res block 7
  75. self.refpad10_1 = nn.ReflectionPad2d(1)
  76. self.conv10_1 = nn.Conv2d(256, 256, 3)
  77. self.in10_1 = InstanceNormalization(256)
  78. # relu
  79. self.refpad10_2 = nn.ReflectionPad2d(1)
  80. self.conv10_2 = nn.Conv2d(256, 256, 3)
  81. self.in10_2 = InstanceNormalization(256)
  82. # + input
  83. ## res block 8
  84. self.refpad11_1 = nn.ReflectionPad2d(1)
  85. self.conv11_1 = nn.Conv2d(256, 256, 3)
  86. self.in11_1 = InstanceNormalization(256)
  87. # relu
  88. self.refpad11_2 = nn.ReflectionPad2d(1)
  89. self.conv11_2 = nn.Conv2d(256, 256, 3)
  90. self.in11_2 = InstanceNormalization(256)
  91. # + input
  92. ##------------------------------------##
  93. self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
  94. self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
  95. self.in12_1 = InstanceNormalization(128)
  96. # relu
  97. self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
  98. self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
  99. self.in13_1 = InstanceNormalization(64)
  100. # relu
  101. self.refpad12_1 = nn.ReflectionPad2d(3)
  102. self.deconv03_1 = nn.Conv2d(64, 3, 7)
  103. # tanh
  104. def forward(self, x):
  105. y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
  106. y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
  107. t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))
  108. ##
  109. y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
  110. t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04
  111. y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
  112. t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05
  113. y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
  114. t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06
  115. y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
  116. t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07
  117. y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
  118. t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08
  119. y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
  120. t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09
  121. y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
  122. t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10
  123. y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
  124. y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11
  125. ##
  126. y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
  127. y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
  128. y = F.tanh(self.deconv03_1(self.refpad12_1(y)))
  129. return y
  130. class InstanceNormalization(nn.Module):
  131. def __init__(self, dim, eps=1e-9):
  132. super(InstanceNormalization, self).__init__()
  133. self.scale = nn.Parameter(torch.FloatTensor(dim))
  134. self.shift = nn.Parameter(torch.FloatTensor(dim))
  135. self.eps = eps
  136. self._reset_parameters()
  137. def _reset_parameters(self):
  138. self.scale.data.uniform_()
  139. self.shift.data.zero_()
  140. def __call__(self, x):
  141. n = x.size(2) * x.size(3)
  142. t = x.view(x.size(0), x.size(1), n)
  143. mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
  144. # Calculate the biased var. torch.var returns unbiased var
  145. var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n))
  146. scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
  147. scale_broadcast = scale_broadcast.expand_as(x)
  148. shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
  149. shift_broadcast = shift_broadcast.expand_as(x)
  150. out = (x - mean) / torch.sqrt(var + self.eps)
  151. out = out * scale_broadcast + shift_broadcast
  152. return out