Yijunmaverick 7 years ago
Commit
45a2472acc

+ 50 - 0
.gitignore

@@ -0,0 +1,50 @@
+_Store
+debug*
+datasets/
+checkpoints/
+results/
+build/
+dist/
+*.png
+torch.egg-info/
+*/runs/
+*/__pycache/
+*/**/__pycache__
+torch/version.py
+torch/csrc/generic/TensorMethods.cpp
+torch/lib/*.so*
+torch/lib/*.dylib*
+torch/lib/*.h
+torch/lib/build
+torch/lib/tmp_install
+torch/lib/include
+torch/lib/torch_shm_manager
+torch/csrc/cudnn/cuDNN.cpp
+torch/csrc/nn/THNN.cwrap
+torch/csrc/nn/THNN.cpp
+torch/csrc/nn/THCUNN.cwrap
+torch/csrc/nn/THCUNN.cpp
+torch/csrc/nn/THNN_generic.cwrap
+torch/csrc/nn/THNN_generic.cpp
+torch/csrc/nn/THNN_generic.h
+docs/src/**/*
+test/data/legacy_modules.t7
+test/data/gpu_tensors.pt
+test/htmlcov
+test/.coverage
+*/*.pyc
+*/**/*.pyc
+*/**/**/*.pyc
+*/**/**/**/*.pyc
+*/**/**/**/**/*.pyc
+*/*.so*
+*/**/*.so*
+*/**/*.dylib*
+test/data/legacy_serialized.pt
+*~
+.idea
+DAVIS/
+pretrained_model/*.t7
+pretrained_model/*.pth
+model_convert/
+

+ 177 - 0
network/Transformer.py

@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Transformer(nn.Module):
+    def __init__(self):
+        super(Transformer, self).__init__()
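+        # CartoonGAN generator: 7x7 conv stem, two stride-2 down-sampling stages,
+        # eight 256-channel residual blocks, two up-sampling stages, and a final
+        # 7x7 conv with tanh; instance normalization is used throughout the body.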
+        #
+        self.refpad01_1 = nn.ReflectionPad2d(3)
+        self.conv01_1 = nn.Conv2d(3, 64, 7)
+        self.in01_1 = InstanceNormalization(64)
+        # relu
+        self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
+        self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
+        self.in02_1 = InstanceNormalization(128)
+        # relu
+        self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
+        self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)   
+        self.in03_1 = InstanceNormalization(256)    
+        # relu
+
+        ## res block 1
+        self.refpad04_1 = nn.ReflectionPad2d(1)
+        self.conv04_1 = nn.Conv2d(256, 256, 3)
+        self.in04_1 = InstanceNormalization(256)
+        # relu
+        self.refpad04_2 = nn.ReflectionPad2d(1)
+        self.conv04_2 = nn.Conv2d(256, 256, 3)
+        self.in04_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 2
+        self.refpad05_1 = nn.ReflectionPad2d(1)
+        self.conv05_1 = nn.Conv2d(256, 256, 3)
+        self.in05_1 = InstanceNormalization(256)
+        # relu
+        self.refpad05_2 = nn.ReflectionPad2d(1)
+        self.conv05_2 = nn.Conv2d(256, 256, 3)
+        self.in05_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 3
+        self.refpad06_1 = nn.ReflectionPad2d(1)
+        self.conv06_1 = nn.Conv2d(256, 256, 3)
+        self.in06_1 = InstanceNormalization(256)
+        # relu
+        self.refpad06_2 = nn.ReflectionPad2d(1)
+        self.conv06_2 = nn.Conv2d(256, 256, 3)
+        self.in06_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 4
+        self.refpad07_1 = nn.ReflectionPad2d(1)
+        self.conv07_1 = nn.Conv2d(256, 256, 3)
+        self.in07_1 = InstanceNormalization(256)
+        # relu
+        self.refpad07_2 = nn.ReflectionPad2d(1)
+        self.conv07_2 = nn.Conv2d(256, 256, 3)
+        self.in07_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 5
+        self.refpad08_1 = nn.ReflectionPad2d(1)
+        self.conv08_1 = nn.Conv2d(256, 256, 3)
+        self.in08_1 = InstanceNormalization(256)
+        # relu
+        self.refpad08_2 = nn.ReflectionPad2d(1)
+        self.conv08_2 = nn.Conv2d(256, 256, 3)
+        self.in08_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 6
+        self.refpad09_1 = nn.ReflectionPad2d(1)
+        self.conv09_1 = nn.Conv2d(256, 256, 3)
+        self.in09_1 = InstanceNormalization(256)
+        # relu
+        self.refpad09_2 = nn.ReflectionPad2d(1)
+        self.conv09_2 = nn.Conv2d(256, 256, 3)
+        self.in09_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 7
+        self.refpad10_1 = nn.ReflectionPad2d(1)
+        self.conv10_1 = nn.Conv2d(256, 256, 3)
+        self.in10_1 = InstanceNormalization(256)
+        # relu
+        self.refpad10_2 = nn.ReflectionPad2d(1)
+        self.conv10_2 = nn.Conv2d(256, 256, 3)
+        self.in10_2 = InstanceNormalization(256)
+        # + input
+
+        ## res block 8
+        self.refpad11_1 = nn.ReflectionPad2d(1)
+        self.conv11_1 = nn.Conv2d(256, 256, 3)
+        self.in11_1 = InstanceNormalization(256)
+        # relu
+        self.refpad11_2 = nn.ReflectionPad2d(1)
+        self.conv11_2 = nn.Conv2d(256, 256, 3)
+        self.in11_2 = InstanceNormalization(256)
+        # + input
+
+        ##------------------------------------##
+        self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
+        self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
+        self.in12_1 = InstanceNormalization(128)
+        # relu
+        self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
+        self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
+        self.in13_1 = InstanceNormalization(64)
+        # relu
+        self.refpad12_1 = nn.ReflectionPad2d(3)
+        self.deconv03_1 = nn.Conv2d(64, 3, 7)
+        # tanh
+
+    def forward(self, x):
+        y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
+        y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
+        t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))
+
+        ##
+        y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
+        t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04
+
+        y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
+        t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05
+
+        y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
+        t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06
+
+        y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
+        t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07
+
+        y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
+        t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08
+
+        y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
+        t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09
+
+        y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
+        t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10
+
+        y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
+        y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11
+        ##
+
+        y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
+        y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
+        y = F.tanh(self.deconv03_1(self.refpad12_1(y)))
+
+        return y
+
+
+class InstanceNormalization(nn.Module):
+    def __init__(self, dim, eps=1e-9):
+        super(InstanceNormalization, self).__init__()
+        self.scale = nn.Parameter(torch.FloatTensor(dim))
+        self.shift = nn.Parameter(torch.FloatTensor(dim))
+        self.eps = eps
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        self.scale.data.uniform_()
+        self.shift.data.zero_()
+
+    def __call__(self, x):
+        n = x.size(2) * x.size(3)
+        t = x.view(x.size(0), x.size(1), n)
+        mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
+        # Calculate the biased var. torch.var returns unbiased var
+        var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n))
+        scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+        scale_broadcast = scale_broadcast.expand_as(x)
+        shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+        shift_broadcast = shift_broadcast.expand_as(x)
+        out = (x - mean) / torch.sqrt(var + self.eps)
+        out = out * scale_broadcast + shift_broadcast
+        return out
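Note: the InstanceNormalization class above normalizes each feature map over its spatial dimensions using the biased variance, then applies a learned per-channel scale and shift. As a quick sanity check (a minimal sketch, assuming a recent PyTorch build where nn.InstanceNorm2d supports affine=True; not part of this commit), its output should match the built-in layer once the parameters are copied over:

    import torch
    import torch.nn as nn
    from network.Transformer import InstanceNormalization

    x = torch.randn(2, 64, 8, 8)
    custom = InstanceNormalization(64, eps=1e-5)
    builtin = nn.InstanceNorm2d(64, eps=1e-5, affine=True)
    builtin.weight.data.copy_(custom.scale.data)  # copy the learned scale
    builtin.bias.data.copy_(custom.shift.data)    # copy the learned shift
    print(torch.allclose(custom(x), builtin(x), atol=1e-5))  # expected: True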

+ 0 - 0
network/__init__.py


+ 8 - 0
pretrained_model/download_pth.sh

@@ -0,0 +1,8 @@
+cd pretrained_model
+
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/pytorch_pth/Hayao_net_G_float.pth
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/pytorch_pth/Hosoda_net_G_float.pth
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/pytorch_pth/Paprika_net_G_float.pth	
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/pytorch_pth/Shinkai_net_G_float.pth
+
+cd ..

+ 8 - 0
pretrained_model/download_t7.sh

@@ -0,0 +1,8 @@
+cd pretrained_model
+
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/torch_t7/Hayao_net_G_float.t7
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/torch_t7/Hosoda_net_G_float.t7
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/torch_t7/Paprika_net_G_float.t7	
+wget -c http://vllab1.ucmerced.edu/~yli62/CartoonGAN/torch_t7/Shinkai_net_G_float.t7
+
+cd ..

+ 98 - 0
src/InstanceNormalization.lua

@@ -0,0 +1,98 @@
+require 'nn'
+
+--[[
+   An implementation of https://arxiv.org/abs/1607.08022
+]]
+
+local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')
+
+function InstanceNormalization:__init(nOutput, eps, momentum, affine)
+   parent.__init(self)
+   self.running_mean = torch.zeros(nOutput)
+   self.running_var = torch.ones(nOutput)
+
+   self.eps = eps or 1e-5
+   self.momentum = momentum or 0.0
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
+   
+   self.nOutput = nOutput
+   self.prev_batch_size = -1
+
+   if self.affine then 
+      self.weight = torch.Tensor(nOutput):uniform()
+      self.bias = torch.Tensor(nOutput):zero()
+      self.gradWeight = torch.Tensor(nOutput)
+      self.gradBias = torch.Tensor(nOutput)
+   end 
+end
+
+function InstanceNormalization:updateOutput(input)
+   self.output = self.output or input.new()
+   assert(input:size(2) == self.nOutput)
+
+   local batch_size = input:size(1)
+   
+   if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type())  then
+      self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
+      self.bn:type(self:type())
+      self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size))
+      self.bn.running_var:copy(self.running_var:repeatTensor(batch_size))
+
+      self.prev_batch_size = input:size(1)
+   end
+
+   -- Get statistics
+   self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1))
+   self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1))
+
+   -- Set params for BN
+   if self.affine then
+      self.bn.weight:copy(self.weight:repeatTensor(batch_size))
+      self.bn.bias:copy(self.bias:repeatTensor(batch_size))
+   end
+
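+   -- Fold the batch into the channel dimension, (N, C, H, W) -> (1, N*C, H, W), so the
+   -- wrapped SpatialBatchNormalization normalizes every sample-channel slice independently.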
+   local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
+   self.output = self.bn:forward(input_1obj):viewAs(input)
+   
+   return self.output
+end
+
+function InstanceNormalization:updateGradInput(input, gradOutput)
+   self.gradInput = self.gradInput or gradOutput.new()
+
+   assert(self.bn)
+
+   local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) 
+   local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) 
+   
+   if self.affine then
+      self.bn.gradWeight:zero()
+      self.bn.gradBias:zero()
+   end
+
+   self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input)
+
+   if self.affine then
+      self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1))
+      self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1))
+   end
+   return self.gradInput
+end
+
+function InstanceNormalization:clearState()
+   self.output = self.output.new()
+   self.gradInput = self.gradInput.new()
+   
+   self.bn:clearState()
+end
+
+function InstanceNormalization:evaluate()
+end
+
+function InstanceNormalization:training()
+end

+ 90 - 0
src/util.lua

@@ -0,0 +1,90 @@
+--
+-- code derived from https://github.com/soumith/dcgan.torch
+--
+
+local util = {}
+
+require 'torch'
+require 'nn'
+require 'lfs'
+
+-- Preprocesses an image before passing it to a net
+-- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
+function util.preprocess(img)
+    -- RGB to BGR
+    local perm = torch.LongTensor{3, 2, 1}
+    img = img:index(1, perm)
+    
+    -- [0,1] to [-1,1]
+    img = img:mul(2):add(-1)
+    
+    -- check that input is in expected range
+    assert(img:max()<=1,"badly scaled inputs")
+    assert(img:min()>=-1,"badly scaled inputs")
+    
+    return img
+end
+
+-- Undo the above preprocessing.
+function util.deprocess(img)
+    -- BGR to RGB
+    local perm = torch.LongTensor{3, 2, 1}
+    img = img:index(1, perm)
+    
+    -- [-1,1] to [0,1]
+    
+    img = img:add(1):div(2)
+    
+    return img
+end
+
+function util.preprocess_batch(batch)
+  for i = 1, batch:size(1) do
+    batch[i] = util.preprocess(batch[i]:squeeze())
+  end
+  return batch
+end
+
+function util.deprocess_batch(batch)
+  for i = 1, batch:size(1) do
+   batch[i] = util.deprocess(batch[i]:squeeze())
+  end
+return batch
+end
+
+--
+-- code derived from AdaIN https://github.com/xunhuang1995/AdaIN-style
+--
+
+function util.extractImageNamesRecursive(dir)
+    local files = {}
+    print("Extracting image paths: " .. dir)
+  
+    local function browseFolder(root, pathTable)
+        for entity in lfs.dir(root) do
+            if entity~="." and entity~=".." then
+                local fullPath=root..'/'..entity
+                local mode=lfs.attributes(fullPath,"mode")
+                if mode=="file" then
+                    local filepath = paths.concat(root, entity)
+  
+                    if string.find(filepath, 'jpg$')
+                    or string.find(filepath, 'png$')
+                    or string.find(filepath, 'jpeg$')
+                    or string.find(filepath, 'JPEG$')
+                    or string.find(filepath, 'ppm$') then
+                        table.insert(pathTable, filepath)
+                    end
+                elseif mode=="directory" then
+                    browseFolder(fullPath, pathTable);
+                end
+            end
+        end
+    end
+
+    browseFolder(dir, files)
+    return files
+end
+
+
+return util
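For reference, the preprocess/deprocess pair above corresponds to the following PyTorch operations (a sketch only, with hypothetical function names; the repository's Python path does the same thing inline in test.py):

    import torch

    def preprocess(img):            # img: float tensor (3, H, W), RGB, values in [0, 1]
        img = img[[2, 1, 0], :, :]  # RGB -> BGR
        return img * 2 - 1          # [0, 1] -> [-1, 1]

    def deprocess(img):             # img: float tensor (3, H, W), BGR, values in [-1, 1]
        img = img[[2, 1, 0], :, :]  # BGR -> RGB
        return (img + 1) / 2        # [-1, 1] -> [0, 1]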

+ 77 - 0
test.lua

@@ -0,0 +1,77 @@
+require 'cutorch'
+require 'nn'
+require 'cunn'
+require 'image'
+require 'optim'
+require 'nngraph'
+require 'paths'
+require 'src/InstanceNormalization'
+util = paths.dofile('src/util.lua')
+
+local cmd = torch.CmdLine()
+
+cmd:option('-input_dir', 'test_img');
+cmd:option('-output_dir', 'test_output', 'Path to save stylized image.')
+cmd:option('-load_size', 450)
+cmd:option('-gpu', 0, '-1 for CPU mode')
+cmd:option('-model_path', './pretrained_model/')
+cmd:option('-style', 'Hosoda')
+
+opt = cmd:parse(arg)
+
+if not paths.dirp(opt.output_dir) then
+    paths.mkdir(opt.output_dir)
+end
+
+if opt.gpu > -1 then
+  cutorch.setDevice(opt.gpu+1)
+end
+
+-- Define model
+local model = torch.load(paths.concat(opt.model_path .. opt.style .. '_net_G_float.t7'))
+model:evaluate()
+if opt.gpu > -1 then
+  print('GPU mode')
+  model:cuda()
+else
+  print('CPU mode')
+  model:float()
+end
+
+contentPaths = {}
+if opt.input_dir ~= '' then 
+  contentPaths = util.extractImageNamesRecursive(opt.input_dir)
+else
+  print('Please specify the input directory')
+end
+
+for i=1, #contentPaths do
+  local contentPath = contentPaths[i]
+  local contentExt = paths.extname(contentPath)
+  local contentName = paths.basename(contentPath, contentExt)
+  -- load image
+  local img = image.load(contentPath, 3, 'float')
+  -- resize image, keep aspect ratio
+  img = image.scale(img, opt.load_size, 'bilinear')
+  local sg = img:size()
+  local input = nil
+  if opt.gpu > -1 then
+    input = torch.zeros(1, sg[1], sg[2], sg[3]):cuda()
+    input[1] = img
+  else
+    input = torch.zeros(1, sg[1], sg[2], sg[3]):float()
+    input[1] = img
+  end
+  -- forward
+  local out = util.deprocess_batch(model:forward(util.preprocess_batch(input)))
+  -- save
+  local savePath = paths.concat(opt.output_dir, contentName .. '_' .. opt.style .. '.' .. contentExt)
+  image.save(savePath, out[1])
+  collectgarbage()
+end
+print('Done!')
+
+
+
+
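Usage note (assuming Torch7's th runner is installed): from the repository root, th test.lua -style Hayao -gpu 0 stylizes every image found under test_img and writes the results to test_output; -input_dir, -output_dir, -model_path and -load_size override the defaults above, and -gpu -1 selects CPU mode.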

+ 74 - 0
test.py

@@ -0,0 +1,74 @@
+import torch
+import os
+import numpy as np
+import argparse
+from PIL import Image
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+import torchvision.utils as vutils
+from network.Transformer import Transformer
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input_dir', default = 'test_img')
+parser.add_argument('--load_size', type = int, default = 450)
+parser.add_argument('--model_path', default = './pretrained_model')
+parser.add_argument('--style', default = 'Hayao')
+parser.add_argument('--output_dir', default = 'test_output')
+parser.add_argument('--gpu', type=int, default = 0)
+
+opt = parser.parse_args()
+
+valid_ext = ['.jpg', '.png']
+
+if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)
+
+# load pretrained model
+model = Transformer()
+model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
+model.eval()
+print(opt.gpu)
+if opt.gpu > -1:
+	print('GPU mode')
+	model.cuda()
+else:
+	print('CPU mode')
+	model.float()
+
+for files in os.listdir(opt.input_dir):
+	ext = os.path.splitext(files)[1]
+	if ext not in valid_ext:
+		continue
+	# load image
+	input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
+	# resize image, keep aspect ratio
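+	# note: PIL's Image.size is (width, height), so h and w below actually hold width and height;
+	# the arithmetic still scales the longer side to load_size while keeping the aspect ratio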
+	h = input_image.size[0]
+	w = input_image.size[1]
+	ratio = h *1.0 / w
+	if ratio > 1:
+		h = opt.load_size
+		w = int(h*1.0/ratio)
+	else:
+		w = opt.load_size
+		h = int(w * ratio)
+	input_image = input_image.resize((h, w), Image.BICUBIC)
+	input_image = np.asarray(input_image)
+	# RGB -> BGR
+	input_image = input_image[:, :, [2, 1, 0]]
+	input_image = transforms.ToTensor()(input_image).unsqueeze(0)
+	# preprocess, (-1, 1)
+	input_image = -1 + 2 * input_image 
+	if opt.gpu > -1:
+		input_image = Variable(input_image).cuda()
+	else:
+		input_image = Variable(input_image).float()
+	# forward
+	output_image = model(input_image)
+	output_image = output_image[0]
+	# BGR -> RGB
+	output_image = output_image[[2, 1, 0], :, :]
+	# deprocess, (0, 1)
+	output_image = output_image.data.cpu().float() * 0.5 + 0.5
+	# save
+	vutils.save_image(output_image, os.path.join(opt.output_dir, files[:-4] + '_' + opt.style + '.jpg'))
+
+print('Done!')
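Usage note: python test.py --style Hayao --gpu 0 --input_dir test_img --output_dir test_output writes each stylized image as <name>_<style>.jpg; --gpu -1 selects CPU mode, and the other downloaded styles (Hosoda, Paprika, Shinkai) are chosen the same way via --style.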

BIN
test_img/15--324.jpg


BIN
test_img/4--24.jpg


BIN
test_img/5--26.jpg


BIN
test_img/6--267.jpg


BIN
test_img/7--129.jpg


BIN
test_img/7--136.jpg


BIN
test_img/7--165.jpg


BIN
test_img/7--88.jpg


BIN
test_img/sjtu.jpg


BIN
test_img/wuda--2.jpg


BIN
test_img/wuda--3.jpg


BIN
test_output/15--324_Hayao.jpg


BIN
test_output/15--324_Hosoda.jpg


BIN
test_output/4--24_Hayao.jpg


BIN
test_output/4--24_Hosoda.jpg


BIN
test_output/5--26_Hayao.jpg


BIN
test_output/5--26_Hosoda.jpg


BIN
test_output/6--267_Hayao.jpg


BIN
test_output/6--267_Hosoda.jpg


BIN
test_output/7--129_Hayao.jpg


BIN
test_output/7--129_Hosoda.jpg


BIN
test_output/7--136_Hayao.jpg


BIN
test_output/7--136_Hosoda.jpg


BIN
test_output/7--165_Hayao.jpg


BIN
test_output/7--165_Hosoda.jpg


BIN
test_output/7--88_Hayao.jpg


BIN
test_output/7--88_Hosoda.jpg


BIN
test_output/sjtu_Hayao.jpg


BIN
test_output/sjtu_Hosoda.jpg


BIN
test_output/wuda--2_Hayao.jpg


BIN
test_output/wuda--2_Hosoda.jpg


BIN
test_output/wuda--3_Hayao.jpg


BIN
test_output/wuda--3_Hosoda.jpg