test.lua 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. require 'cutorch'
  2. require 'nn'
  3. require 'cunn'
  4. require 'image'
  5. require 'nngraph'
  6. require 'paths'
  7. require 'src/InstanceNormalization'
  8. util = paths.dofile('src/util.lua')
  -- Command-line options. The parsed table is kept in the global `opt`
  -- (as in the original) because the rest of the script reads it directly.
  local cmd = torch.CmdLine()
  cmd:option('-input_dir', 'test_img',
             'Directory of content images to stylize')
  cmd:option('-output_dir', 'test_output', 'Path to save stylized image.')
  cmd:option('-load_size', 450,
             'Longest side (pixels) each input is rescaled to; aspect ratio is kept')
  cmd:option('-gpu', 0, 'GPU id (0-based); -1 for CPU mode')
  cmd:option('-model_path', './pretrained_model/',
             'Directory containing <style>_net_G_float.t7')
  cmd:option('-style', 'Hosoda',
             'Style name: selects <style>_net_G_float.t7 and suffixes output file names')
  opt = cmd:parse(arg)
  -- Ensure the output directory exists. Fail here with a clear message if it
  -- cannot be created (e.g. a missing parent directory) rather than later
  -- inside image.save.
  if not paths.dirp(opt.output_dir) then
    local ok, err = paths.mkdir(opt.output_dir)
    if not ok then
      error('Could not create output directory ' .. opt.output_dir
            .. ': ' .. tostring(err))
    end
  end
  -- Select the CUDA device. -gpu is 0-based while cutorch.setDevice takes a
  -- 1-based index; a -gpu of -1 (or lower) means CPU mode and skips this.
  if -1 < opt.gpu then
    local device = opt.gpu + 1
    cutorch.setDevice(device)
  end
  -- Define model: load the pretrained generator '<style>_net_G_float.t7'.
  -- Directory and file name are passed to paths.concat separately so it
  -- inserts the separator itself; -model_path then works with or without a
  -- trailing '/'.
  local modelFile = paths.concat(opt.model_path, opt.style .. '_net_G_float.t7')
  if not paths.filep(modelFile) then
    error('Model file not found: ' .. modelFile)
  end
  local model = torch.load(modelFile)
  -- Switch the network to evaluation (non-training) mode before inference.
  model:evaluate()
  -- Convert the model to the tensor type of the selected device.
  if opt.gpu > -1 then
    print('GPU mode')
    model:cuda()
  else
    print('CPU mode')
    model:float()
  end
  -- Collect the image paths to process from -input_dir (presumably a
  -- recursive listing, judging by the util helper's name). An empty
  -- -input_dir is a usage error: abort instead of silently processing nothing.
  if opt.input_dir == '' then
    error('Please specify the input directory')
  end
  local contentPaths = util.extractImageNamesRecursive(opt.input_dir)
  -- Stylize every collected image and save it to -output_dir as
  -- '<name>_<style>.<ext>', where <ext> is the input's extension.
  for i = 1, #contentPaths do
    local contentPath = contentPaths[i]
    local contentExt = paths.extname(contentPath)
    local contentName = paths.basename(contentPath, contentExt)
    -- paths.extname yields nil for extension-less names; fall back to PNG so
    -- building the output name below does not fail on '.' .. nil.
    local saveExt = contentExt or 'png'

    -- Load as a 3-channel float image, then rescale; with a numeric size,
    -- image.scale keeps the aspect ratio and sets the longest side.
    local img = image.load(contentPath, 3, 'float')
    img = image.scale(img, opt.load_size, 'bilinear')

    -- Prepend a batch dimension of size 1 (no copy), then convert to the
    -- model's tensor type (:float() is a no-op here on the CPU path).
    local sz = img:size()
    local input = img:view(1, sz[1], sz[2], sz[3])
    if opt.gpu > -1 then
      input = input:cuda()
    else
      input = input:float()
    end

    -- Forward pass; deprocess_batch presumably inverts preprocess_batch
    -- (both live in src/util.lua) -- TODO confirm.
    local out = util.deprocess_batch(model:forward(util.preprocess_batch(input)))

    local savePath = paths.concat(opt.output_dir,
      contentName .. '_' .. opt.style .. '.' .. saveExt)
    image.save(savePath, out[1])
    -- Release per-image tensors before the next iteration.
    collectgarbage()
  end
  print('Done!')