# test.py (2.1 KB)
  1. import torch
  2. import os
  3. import numpy as np
  4. import argparse
  5. from PIL import Image
  6. import torchvision.transforms as transforms
  7. from torch.autograd import Variable
  8. import torchvision.utils as vutils
  9. from network.Transformer import Transformer
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--input_dir', default = 'test_img')
  12. parser.add_argument('--load_size', default = 450)
  13. parser.add_argument('--model_path', default = './pretrained_model')
  14. parser.add_argument('--style', default = 'Hayao')
  15. parser.add_argument('--output_dir', default = 'test_output')
  16. parser.add_argument('--gpu', type=int, default = 0)
  17. opt = parser.parse_args()
  18. valid_ext = ['.jpg', '.png']
  19. if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)
  20. # load pretrained model
  21. model = Transformer()
  22. model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
  23. model.eval()
  24. if opt.gpu > -1:
  25. print('GPU mode')
  26. model.cuda()
  27. else:
  28. print('CPU mode')
  29. model.float()
  30. for files in os.listdir(opt.input_dir):
  31. ext = os.path.splitext(files)[1]
  32. if ext not in valid_ext:
  33. continue
  34. # load image
  35. input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
  36. # resize image, keep aspect ratio
  37. h = input_image.size[0]
  38. w = input_image.size[1]
  39. ratio = h *1.0 / w
  40. if ratio > 1:
  41. h = opt.load_size
  42. w = int(h*1.0/ratio)
  43. else:
  44. w = opt.load_size
  45. h = int(w * ratio)
  46. input_image = input_image.resize((h, w), Image.BICUBIC)
  47. input_image = np.asarray(input_image)
  48. # RGB -> BGR
  49. input_image = input_image[:, :, [2, 1, 0]]
  50. input_image = transforms.ToTensor()(input_image).unsqueeze(0)
  51. # preprocess, (-1, 1)
  52. input_image = -1 + 2 * input_image
  53. if opt.gpu > -1:
  54. input_image = Variable(input_image, volatile=True).cuda()
  55. else:
  56. input_image = Variable(input_image, volatile=True).float()
  57. # forward
  58. output_image = model(input_image)
  59. output_image = output_image[0]
  60. # BGR -> RGB
  61. output_image = output_image[[2, 1, 0], :, :]
  62. # deprocess, (0, 1)
  63. output_image = output_image.data.cpu().float() * 0.5 + 0.5
  64. # save
  65. vutils.save_image(output_image, os.path.join(opt.output_dir, files[:-4] + '_' + opt.style + '.jpg'))
  66. print('Done!')