# test.py (2.1 KB)

  1. import torch
  2. import os
  3. import numpy as np
  4. import argparse
  5. from PIL import Image
  6. import torchvision.transforms as transforms
  7. from torch.autograd import Variable
  8. import torchvision.utils as vutils
  9. from network.Transformer import Transformer
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--input_dir', default = 'test_img')
  12. parser.add_argument('--load_size', default = 450)
  13. parser.add_argument('--model_path', default = './pretrained_model')
  14. parser.add_argument('--style', default = 'Hayao')
  15. parser.add_argument('--output_dir', default = 'test_output')
  16. parser.add_argument('--gpu', type=int, default = 0)
  17. opt = parser.parse_args()
  18. valid_ext = ['.jpg', '.png']
  19. if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)
  20. # load pretrained model
  21. model = Transformer()
  22. model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
  23. model.eval()
  24. print(opt.gpu)
  25. if opt.gpu > -1:
  26. print('GPU mode')
  27. model.cuda()
  28. else:
  29. print('CPU mode')
  30. model.float()
  31. for files in os.listdir(opt.input_dir):
  32. ext = os.path.splitext(files)[1]
  33. if ext not in valid_ext:
  34. continue
  35. # load image
  36. input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
  37. # resize image, keep aspect ratio
  38. h = input_image.size[0]
  39. w = input_image.size[1]
  40. ratio = h *1.0 / w
  41. if ratio > 1:
  42. h = opt.load_size
  43. w = int(h*1.0/ratio)
  44. else:
  45. w = opt.load_size
  46. h = int(w * ratio)
  47. input_image = input_image.resize((h, w), Image.BICUBIC)
  48. input_image = np.asarray(input_image)
  49. # RGB -> BGR
  50. input_image = input_image[:, :, [2, 1, 0]]
  51. input_image = transforms.ToTensor()(input_image).unsqueeze(0)
  52. # preprocess, (-1, 1)
  53. input_image = -1 + 2 * input_image
  54. if opt.gpu > -1:
  55. input_image = Variable(input_image).cuda()
  56. else:
  57. input_image = Variable(input_image).float()
  58. # forward
  59. output_image = model(input_image)
  60. output_image = output_image[0]
  61. # BGR -> RGB
  62. output_image = output_image[[2, 1, 0], :, :]
  63. # deprocess, (0, 1)
  64. output_image = output_image.data.cpu().float() * 0.5 + 0.5
  65. # save
  66. vutils.save_image(output_image, os.path.join(opt.output_dir, files[:-4] + '_' + opt.style + '.jpg'))
  67. print('Done!')