comixify/style_transfer/style_transfer.py
Maciej Pęśko d3e6b4974d Add minor config changes
- Model caching
- Production settings
2018-09-06 21:27:40 +02:00

67 lines
2.2 KiB
Python

import os
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from django.conf import settings
from django.core.cache import cache
from torch.autograd import Variable
from CartoonGAN.network.Transformer import Transformer
class StyleTransfer():
@classmethod
def get_stylized_frames(cls, frames, method="cartoon_gan", gpu=settings.GPU, **kwargs):
if method == "cartoon_gan":
return cls._cartoon_gan_stylize(frames, gpu=gpu, **kwargs)
@staticmethod
def _cartoon_gan_stylize(frames, gpu=True, **kwargs):
style = kwargs.get("style", "Hayao")
resize = kwargs.get("resize", 450)
model_cache_key = 'model_cache'
model = cache.get(model_cache_key) # get model from cache
if model is None:
# load pretrained model
model = Transformer()
model.load_state_dict(torch.load(os.path.join("CartoonGAN/pretrained_model", style + "_net_G_float.pth")))
model.eval()
model.cuda() if gpu else model.float()
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
stylized_imgs = []
for img in frames:
# resize image, keep aspect ratio
h, w, _ = img.shape
ratio = h * 1.0 / w
if ratio > 1:
h = resize
w = int(h * 1.0 / ratio)
else:
w = resize
h = int(w * ratio)
input_image = cv2.resize(img, (w, h))
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
# preprocess, (-1, 1)
input_image = -1 + 2 * input_image
input_image = Variable(input_image).cuda() if gpu else Variable(input_image).float()
# forward
output_image = model(input_image)
output_image = output_image[0]
# deprocess, (0, 1)
output_image = (output_image.data.cpu().float() * 0.5 + 0.5).numpy()
# switch channels -> (c, h, w) -> (h, w, c)
output_image = np.rollaxis(output_image, 0, 3)
# append image to result images
stylized_imgs.append(output_image)
return stylized_imgs