-- test.lua
  -- Stylize every image returned by util.extractImageNamesRecursive(-input_dir)
  -- with a pretrained generator and write each result to -output_dir as
  -- <input basename>_<style>.<input extension>.
  require 'cutorch'
  require 'nn'
  require 'cunn'
  require 'image'
  require 'optim'
  require 'nngraph'
  require 'paths'
  require 'src/InstanceNormalization'

  -- NOTE(review): `util` and `opt` are kept global in case code loaded from
  -- src/ reads them as globals; confirm before making them local.
  util = paths.dofile('src/util.lua')

  local cmd = torch.CmdLine()
  cmd:option('-input_dir', 'test_img')
  cmd:option('-output_dir', 'test_output', 'Path to save stylized image.')
  cmd:option('-load_size', 450)
  cmd:option('-gpu', 0, '-1 for CPU mode')
  cmd:option('-model_path', './pretrained_model/')
  cmd:option('-style', 'Hosoda')
  opt = cmd:parse(arg)

  local use_gpu = opt.gpu > -1

  if not paths.dirp(opt.output_dir) then
    paths.mkdir(opt.output_dir)
  end

  if use_gpu then
    -- -gpu is 0-based; cutorch device ids are 1-based
    cutorch.setDevice(opt.gpu + 1)
  end

  -- Load the generator. Passing the parts separately lets paths.concat insert
  -- the separator, so -model_path works with or without a trailing slash.
  local model = torch.load(paths.concat(opt.model_path, opt.style .. '_net_G_float.t7'))
  model:evaluate()
  if use_gpu then
    print('GPU mode')
    model:cuda()
  else
    print('CPU mode')
    model:float()
  end

  -- Fail loudly on a missing input directory instead of silently
  -- processing nothing and still reporting 'Done!'.
  if opt.input_dir == '' then
    error('Please specify the input directory')
  end
  if not paths.dirp(opt.input_dir) then
    error('Input directory does not exist: ' .. opt.input_dir)
  end
  local contentPaths = util.extractImageNamesRecursive(opt.input_dir)

  for _, contentPath in ipairs(contentPaths) do
    local contentExt = paths.extname(contentPath)
    local contentName = paths.basename(contentPath, contentExt)

    -- load as a 3-channel float image, then resize keeping the aspect ratio
    local img = image.load(contentPath, 3, 'float')
    img = image.scale(img, opt.load_size, 'bilinear')

    -- prepend a batch dimension of size 1
    local input = img:contiguous():view(1, img:size(1), img:size(2), img:size(3))
    if use_gpu then
      input = input:cuda()
    end

    local out = util.deprocess_batch(model:forward(util.preprocess_batch(input)))

    local savePath = paths.concat(opt.output_dir, contentName .. '_' .. opt.style .. '.' .. contentExt)
    -- :float() copies a CUDA result back to host memory; an existing
    -- FloatTensor is returned unchanged
    image.save(savePath, out[1]:float())
    collectgarbage()
  end

  print('Done!')