1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
import argparse
import os

import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.autograd import Variable

from network.Transformer import Transformer
# File extensions (compared in lower case) that are treated as input images.
valid_ext = ('.jpg', '.png')


def build_parser():
    """Return the argument parser describing the script's command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', default='test_img')
    # type=int: without it a value passed on the command line arrives as a
    # string and breaks the resize arithmetic.
    parser.add_argument('--load_size', type=int, default=450)
    parser.add_argument('--model_path', default='./pretrained_model')
    parser.add_argument('--style', default='Hayao')
    parser.add_argument('--output_dir', default='test_output')
    parser.add_argument('--gpu', type=int, default=0)
    return parser


def is_valid_image(filename):
    """Return True when *filename* ends in one of ``valid_ext`` (any case)."""
    return os.path.splitext(filename)[1].lower() in valid_ext


def scaled_size(width, height, load_size):
    """Return ``(new_width, new_height)`` for an aspect-preserving resize.

    The longer side becomes ``load_size``; the shorter side is scaled by the
    same factor, truncated to an int, and clamped to at least 1 pixel so that
    very elongated images still produce a valid size.
    """
    ratio = width * 1.0 / height
    if ratio > 1:
        return load_size, max(1, int(load_size * 1.0 / ratio))
    return max(1, int(load_size * ratio)), load_size


def output_name(filename, style):
    """Return the output file name: ``<stem>_<style>.jpg``."""
    # splitext instead of slicing off 4 characters, so the stem is correct
    # whatever the length of the extension.
    return os.path.splitext(filename)[0] + '_' + style + '.jpg'


def main():
    """Stylize every valid image in ``--input_dir`` and save it to ``--output_dir``.

    Loads the ``<style>_net_G_float.pth`` weights from ``--model_path`` into a
    ``Transformer``, runs each image through it on the GPU selected by
    ``--gpu`` (or on the CPU when ``--gpu`` is negative) and writes the result
    as a JPEG.
    """
    opt = build_parser().parse_args()

    # makedirs so nested output paths work too.
    if not os.path.isdir(opt.output_dir):
        os.makedirs(opt.output_dir)

    # load pretrained model
    model = Transformer()
    weights_path = os.path.join(opt.model_path, opt.style + '_net_G_float.pth')
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    # The weights are mapped to CPU storage first so a checkpoint saved from a
    # GPU can still be loaded on a CPU-only machine.
    state_dict = torch.load(weights_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()

    print(opt.gpu)
    use_gpu = opt.gpu > -1
    if use_gpu:
        print('GPU mode')
        # Pass the index so --gpu N selects device N, not the default device.
        model.cuda(opt.gpu)
    else:
        print('CPU mode')
        model.float()

    to_tensor = transforms.ToTensor()

    # sorted() only makes the processing order deterministic.
    for filename in sorted(os.listdir(opt.input_dir)):
        if not is_valid_image(filename):
            continue

        # load image; the context manager closes the file once it is decoded
        with Image.open(os.path.join(opt.input_dir, filename)) as raw_image:
            image = raw_image.convert("RGB")

        # resize image, keep aspect ratio (PIL sizes are (width, height))
        width, height = image.size
        image = image.resize(scaled_size(width, height, opt.load_size), Image.BICUBIC)

        # RGB -> BGR
        pixels = np.asarray(image)[:, :, [2, 1, 0]]
        batch = to_tensor(pixels).unsqueeze(0)
        # preprocess, (-1, 1)
        batch = -1 + 2 * batch
        if use_gpu:
            batch = batch.cuda(opt.gpu)
        else:
            batch = batch.float()

        # forward; no_grad avoids building an autograd graph during inference
        with torch.no_grad():
            result = model(batch)[0]

        # BGR -> RGB
        result = result[[2, 1, 0], :, :]
        # deprocess, (0, 1)
        result = result.data.cpu().float() * 0.5 + 0.5

        # save
        vutils.save_image(result, os.path.join(opt.output_dir, output_name(filename, opt.style)))

    print('Done!')


if __name__ == '__main__':
    main()
|