mirror of
https://github.com/maciej3031/comixify.git
synced 2026-03-11 08:54:35 +00:00
* Add endpoint to getting images * Add migration * Fix minor bug * Fix another bug * Fix problem with CartoonGAN not change model and add timestamp to Comic * Add minimum response delay Update sample
106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
import os
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
from torch.autograd import Variable
|
|
|
|
from CartoonGAN.network.Transformer import Transformer
|
|
from ComixGAN.model import ComixGAN
|
|
from utils import profile
|
|
|
|
# load pretrained model
|
|
# Shared ComixGAN instance, created once when this module is imported;
# StyleTransfer._comix_gan_stylize uses its `graph`, `session` and `model`.
comixGAN = ComixGAN()
|
|
|
|
|
|
class StyleTransfer():
    """Stylize a list of frames with ComixGAN or one of the CartoonGAN models.

    Frames are numpy image arrays (height x width [x channels]) as consumed by
    ``cv2.resize`` / ``np.stack``; presumably uint8 images decoded with
    OpenCV -- confirm against the caller.
    """

    # Number of frames passed to ComixGAN in a single ``predict`` call.
    _COMIX_GAN_BATCH_SIZE = 4

    # Supported CartoonGAN styles mapped to the cache key of their model.
    # Each style has its own key so one style never serves another's model.
    _CARTOON_GAN_CACHE_KEYS = {
        'Hayao': 'model_cache_hayao',
        'Hosoda': 'model_cache_hosoda',
    }

    @classmethod
    @profile
    def get_stylized_frames(cls, frames, style_transfer_mode=0, gpu=settings.GPU):
        """Dispatch ``frames`` to the model selected by ``style_transfer_mode``.

        Args:
            frames: list of image arrays to stylize.
            style_transfer_mode: 0 = ComixGAN, 1 = CartoonGAN "Hayao",
                2 = CartoonGAN "Hosoda".
            gpu: run the CartoonGAN models on CUDA when true (ignored by
                ComixGAN).

        Returns:
            A list of stylized images as float numpy arrays.

        Raises:
            ValueError: if ``style_transfer_mode`` is not 0, 1 or 2.
        """
        if style_transfer_mode == 0:
            return cls._comix_gan_stylize(frames=frames)
        if style_transfer_mode == 1:
            return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hayao')
        if style_transfer_mode == 2:
            return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hosoda')
        # Previously an unknown mode fell through and returned None, which
        # only surfaced later as an unrelated error in the caller.
        raise ValueError(
            'Unsupported style_transfer_mode: {!r}'.format(style_transfer_mode)
        )

    @staticmethod
    def _resize_images(frames, size=384):
        """Resize every frame so its longer side equals ``size``.

        The aspect ratio is preserved (up to integer truncation).

        Args:
            frames: iterable of image arrays; only the first two axes
                (height, width) are inspected.
            size: target length in pixels of the longer side.

        Returns:
            A new list with the resized images, in input order.
        """
        resized_images = []
        for img in frames:
            # shape[:2] also accepts 2-D (grayscale) arrays, unlike a
            # three-way unpack of img.shape.
            h, w = img.shape[:2]
            ratio = h / w
            if ratio > 1:
                new_h = size
                # max(1, ...) keeps very elongated images from collapsing to
                # a zero-pixel side, which cv2.resize rejects.
                new_w = max(1, int(new_h * 1.0 / ratio))
            else:
                new_w = size
                new_h = max(1, int(new_w * ratio))
            # cv2.resize takes the target size as (width, height).
            resized_images.append(cv2.resize(img, (new_w, new_h)))
        return resized_images

    @classmethod
    def _comix_gan_stylize(cls, frames):
        """Stylize ``frames`` with the module-level ComixGAN model.

        Frames are downscaled first when the largest dimension of the first
        frame exceeds ``settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER``.

        Args:
            frames: list of image arrays; all frames must share one shape so
                they can be stacked into a batch.

        Returns:
            A list of stylized images (float arrays), one per input frame;
            an empty list when ``frames`` is empty.
        """
        # Nothing to do; also avoids IndexError on frames[0] and the
        # ValueError np.concatenate raises for an empty sequence.
        if len(frames) == 0:
            return []

        if max(frames[0].shape) > settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER:
            frames = cls._resize_images(frames, size=settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER)

        batch_size = cls._COMIX_GAN_BATCH_SIZE
        stylized_imgs = []
        # predict() has to run inside the graph/session owned by comixGAN.
        with comixGAN.graph.as_default():
            with comixGAN.session.as_default():
                for i in range(0, len(frames), batch_size):
                    # Scale pixel values from [0, 255] to [-1, 1].
                    batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
                    stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
                    # NOTE(review): the exact inverse of the scaling above
                    # would divide by 2; 1.25 is kept unchanged because it may
                    # be a deliberate brightening (results can exceed 255)
                    # -- confirm.
                    stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 1.25))

        return list(np.concatenate(stylized_imgs, axis=0))

    @classmethod
    def _cartoon_gan_stylize(cls, frames, gpu=True, style='Hayao'):
        """Stylize ``frames`` with the CartoonGAN model for ``style``.

        The model is fetched from the Django cache; on a miss it is built
        from the pretrained weights and stored in the cache with no timeout.

        Args:
            frames: list of image arrays.
            gpu: move the model and the inputs to CUDA when true.
            style: 'Hayao' or 'Hosoda'.

        Returns:
            A list of stylized images as float (h, w, c) arrays scaled by
            255, one per input frame.

        Raises:
            ValueError: if ``style`` is not a supported CartoonGAN style.
        """
        model_cache_key = cls._CARTOON_GAN_CACHE_KEYS.get(style)
        if model_cache_key is None:
            # ValueError is an Exception subclass, so callers catching the
            # previously raised bare Exception keep working.
            raise ValueError('No such CartoonGAN model!')

        model = cache.get(model_cache_key)  # get model from cache
        if model is None:
            # Cache miss: load the pretrained model.
            model = Transformer()
            weights_path = os.path.join("CartoonGAN/pretrained_model", style + "_net_G_float.pth")
            model.load_state_dict(torch.load(weights_path))
            model.eval()
            if gpu:
                model.cuda()
            else:
                model.float()
            # None is the timeout parameter: keep the entry forever.
            cache.set(model_cache_key, model, None)

        frames = cls._resize_images(frames, size=450)

        # The transform is stateless, so build it once instead of per frame.
        to_tensor = transforms.ToTensor()
        stylized_imgs = []
        # NOTE(review): this forward pass still records autograd state;
        # wrapping the loop in torch.no_grad() would save memory if the
        # installed torch version provides it -- confirm before enabling.
        for img in frames:
            # Add a leading batch axis of size 1.
            input_image = to_tensor(img).unsqueeze(0)

            # preprocess, (-1, 1) -- assumes ToTensor produced values in [0, 1]
            input_image = -1 + 2 * input_image
            input_image = Variable(input_image)
            input_image = input_image.cuda() if gpu else input_image.float()

            # forward; take the single image out of the batch
            output_image = model(input_image)[0]

            # deprocess, (0, 1)
            output_image = (output_image.data.cpu().float() * 0.5 + 0.5).numpy()

            # switch channels -> (c, h, w) -> (h, w, c)
            output_image = np.rollaxis(output_image, 0, 3)

            # append image to result images
            stylized_imgs.append(255 * output_image)

        return stylized_imgs
|