Add comix gan (#16)

* Add ComixGAN #1

* Add minor fixes

* Reduce batch size

* Fix concatenate bug

* Fix GPU memory not released problem #2

* Build client

* Fix style_transfer_mode bug

* Improve timings

* Add minor fix with GPU name

* Fix tf session problem

* Compile frontend

* Change parameters

* Fix occasional  yt Error
This commit is contained in:
Maciej Pęśko 2018-11-10 22:45:55 +01:00 committed by GitHub
parent 75810f7c84
commit 1e5252b8f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 352 additions and 212 deletions

24
ComixGAN/model.py Normal file
View file

@ -0,0 +1,24 @@
import errno
import os
import tensorflow as tf
from django.conf import settings
from keras.models import load_model
from keras_contrib.layers import InstanceNormalization
class ComixGAN:
def __init__(self):
if not os.path.exists(settings.COMIX_GAN_MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.COMIX_GAN_MODEL_PATH)
self.graph = tf.Graph()
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.6
config.gpu_options.allow_growth = True
self.session = tf.Session(graph=self.graph, config=config)
with self.graph.as_default():
with self.session.as_default():
with tf.device('/device:GPU:0'):
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
custom_objects={'InstanceNormalization': InstanceNormalization})

Binary file not shown.

Binary file not shown.

View file

@ -23,7 +23,7 @@ class Video(models.Model):
yt_pafy = pafy.new(yt_url)
# Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px
for stream in yt_pafy.videostreams:
for stream in reversed(yt_pafy.videostreams):
if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480:
tmp_name = uuid.uuid4().hex + ".mp4"
relative_path = jj('raw_videos', tmp_name)
@ -34,22 +34,24 @@ class Video(models.Model):
else:
raise TooLargeFile()
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0):
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0, style_transfer_mode=0):
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
video=self,
frames_mode=frames_mode,
rl_mode=rl_mode,
image_assessment_mode=image_assessment_mode
)
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes)
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes,
style_transfer_mode=style_transfer_mode)
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
timings = {
'keyframes_extraction_time': keyframes_extraction_time,
'stylization_time': stylization_time,
'layout_generation_time': layout_generation_time,
**keyframes_timings
'keyframes_extraction_time_details': keyframes_timings
}
return comic_image, timings
@ -61,7 +63,7 @@ class Comic(models.Model):
@profile
def create_from_nparray(cls, nparray_file, video):
if nparray_file.max() <= 1:
nparray_file = (nparray_file * 255).astype(int)
nparray_file = (nparray_file).astype(int)
tmp_name = uuid.uuid4().hex + ".png"
cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file)
with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file:

View file

@ -8,7 +8,10 @@ class VideoSerializer(serializers.Serializer):
file = serializers.FileField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
def validate(self, attrs):
file = attrs.get("file")
@ -23,4 +26,7 @@ class YouTubeDownloadSerializer(serializers.Serializer):
url = serializers.URLField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
default=settings.DEFAULT_STYLE_TRANSFER_MODE)

View file

@ -22,7 +22,8 @@ class Comixify(APIView):
comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
)
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time
@ -54,7 +55,8 @@ class ComixifyFromYoutube(APIView):
comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
)
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time

View file

@ -28,6 +28,6 @@ class LayoutGenerator():
def _pad_images(frames):
padded_result_imgs = []
for img in frames:
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(1, 1, 1))
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(255, 255, 255))
padded_result_imgs.append(padded_img)
return padded_result_imgs

View file

@ -10,10 +10,7 @@ RUN apt-get update && apt-get install -y apt-utils software-properties-common &&
libsnappy-dev protobuf-compiler \
python-numpy python-setuptools python-scipy \
libavformat-dev libswscale-dev unzip && \
python3.6 -m pip install --upgrade pip && \
python3.6 -m pip install jupyter ipywidgets jupyterlab && \
python3.6 -m pip install h5py keras && \
python3.6 -m pip install scikit-image opencv-contrib-python pyyaml
python3.6 -m pip install --upgrade pip
RUN mkdir /comixify
COPY ./Makefile.config /comixify/Makefile.config
@ -45,7 +42,8 @@ RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig &&
WORKDIR /comixify
COPY . /comixify
RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \
python3.6 -m pip install -r requirements.txt
python3.6 -m pip install -r requirements.txt && \
python3.6 -m pip install git+https://www.github.com/keras-team/keras-contrib.git
# Port to expose

View file

@ -32,7 +32,8 @@ class App extends React.Component {
result_comics: null,
framesMode: "0",
rlMode: "0",
imageAssessment: "0"
imageAssessment: "0",
styleTransferMode: "0",
};
this.onVideoDrop = this.onVideoDrop.bind(this);
this.onModelChange = this.onModelChange.bind(this);
@ -40,6 +41,7 @@ class App extends React.Component {
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
this.onSamplingChange = this.onSamplingChange.bind(this);
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
this.styleTransferChange = this.styleTransferChange.bind(this);
}
static onVideoUploadProgress(progressEvent) {
let percentCompleted = Math.round(
@ -52,6 +54,12 @@ class App extends React.Component {
this.setState({
rlMode: value
})
}
styleTransferChange(e) {
let value = e.currentTarget.value;
this.setState({
styleTransferMode: value
})
}
onSamplingChange(e) {
let value = e.currentTarget.value;
@ -78,12 +86,13 @@ class App extends React.Component {
}
}
processVideo(video) {
let { framesMode, rlMode, imageAssessment } = this.state;
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
let data = new FormData();
data.append("file", video);
data.set('frames_mode', parseInt(framesMode));
data.set('rl_mode', parseInt(rlMode));
data.set("image_assessment_mode", parseInt(imageAssessment));
data.set('style_transfer_mode', parseInt(styleTransferMode));
post(COMIXIFY_API, data, {
headers: { "content-type": "multipart/form-data" },
onUploadProgress: App.onVideoUploadProgress
@ -111,12 +120,13 @@ class App extends React.Component {
this.processVideo(files[0]);
}
submitYouTube(link) {
let { framesMode, rlMode, imageAssessment } = this.state;
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
post(FROM_YOUTUBE_API, {
url: link,
frames_mode: parseInt(framesMode),
rl_mode: parseInt(rlMode),
image_assessment_mode: parseInt(imageAssessment)
image_assessment_mode: parseInt(imageAssessment),
style_transfer_mode: parseInt(styleTransferMode)
})
.then(this.handleResponse)
.catch(err => {
@ -143,7 +153,7 @@ class App extends React.Component {
}
render() {
let {
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment, styleTransferMode
} = this.state;
let showUsage = [
App.appStates.INITIAL,
@ -231,6 +241,36 @@ class App extends React.Component {
onChange={this.onImageAssessmentChange}
/>
<label htmlFor="image-assessment-1">Popularity</label>
</div>
<div>
<span>Style Transfer model:</span>
<input
type="radio"
name="style-model"
id="style-model-0"
value="0"
checked={styleTransferMode === "0"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-0">ComixGAN</label>
<input
type="radio"
name="style-model"
id="style-model-1"
value="1"
checked={styleTransferMode === "1"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-1">CartoonGAN-Hayao</label>
<input
type="radio"
name="style-model"
id="style-model-2"
value="2"
checked={styleTransferMode === "2"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-2">CartoonGAN-Hosoda</label>
</div>
</div>
)}

File diff suppressed because one or more lines are too long

View file

@ -60,7 +60,7 @@
<i>* Max size of video is 50 MB</i>
<h4>Respose:</h4>
<h4>Response:</h4>
<code>
{
<br>

View file

@ -18,7 +18,7 @@ import logging
from utils import jj, profile
from keyframes_rl.models import DSN
from popularity.models import PopularityPredictor
from neural_image_assessment.models import NeuralImageAssessment
from neural_image_assessment.model import NeuralImageAssessment
from keyframes.kts import cpd_auto
from keyframes.utils import batch

View file

@ -1,27 +1,27 @@
import os
import errno
import os
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array
from keras.applications.nasnet import preprocess_input
import tensorflow as tf
from PIL import Image
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
from keras.applications.nasnet import preprocess_input
from keras.models import load_model
from keras.preprocessing.image import img_to_array
from django.conf import settings
class NeuralImageAssessment:
def __init__(self):
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)
if not os.path.exists(settings.NIMA_MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
self.graph = tf.Graph()
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1
config.gpu_options.per_process_gpu_memory_fraction = 0.6
config.gpu_options.allow_growth = True
self.session = tf.Session(graph=self.graph, config=config)
with self.graph.as_default():
self.model = load_model(MODEL_PATH)
with self.session.as_default():
self.model = load_model(settings.NIMA_MODEL_PATH)
@staticmethod
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
@ -46,13 +46,14 @@ class NeuralImageAssessment:
def get_assessment_score(self, img_array):
with self.graph.as_default():
target_size = (224, 224)
img = NeuralImageAssessment.resize_image(img_array, target_size)
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
with self.session.as_default():
target_size = (224, 224)
img = NeuralImageAssessment.resize_image(img_array, target_size)
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
mean = NeuralImageAssessment.mean_score(scores)
return mean

View file

@ -72,6 +72,10 @@ echo 'export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
source ~/.bashrc
# CHECK CUDNN VERSION
cat /usr/include/x86_64-linux-gnu/cudnn_v*.h | grep CUDNN_MAJOR -A 2
# INSTALL PACKAGES
export LC_ALL="en_US.UTF-8"
export LC_CTYPE="en_US.UTF-8"

View file

@ -1,16 +1,45 @@
astor==0.7.1
cloudpickle==0.6.1
cycler==0.10.0
dask==0.20.0
decorator==4.3.0
Django==2.0.7
django-rest-framework==0.1.0
djangorestframework==3.8.2
gast==0.2.0
grpcio==1.16.0
gunicorn==19.9.0
h5py==2.8.0
Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
kiwisolver==1.0.1
Markdown==3.0.1
matplotlib==3.0.1
networkx==2.2
numpy==1.14.5
opencv-contrib-python==3.4.3.18
opencv-python==3.4.2.17
pafy==0.5.4
Pillow==5.2.0
protobuf==3.6.1
psycopg2==2.7.5
pyparsing==2.3.0
python-dateutil==2.7.5
pytz==2018.5
PyWavelets==1.0.1
PyYAML==3.13
scikit-image==0.14.1
scikit-learn==0.20.0
scipy==1.1.0
six==1.11.0
sklearn==0.0
tensorboard==1.10.0
tensorflow-gpu==1.10.1
termcolor==1.1.0
toolz==0.9.0
torch==0.4.1
torchvision==0.2.1
scikit-learn==0.19.2
youtube-dl==2018.9.18
Werkzeug==0.14.1
youtube-dl==2018.11.7
tensorflow-gpu==1.10.1

View file

@ -1,148 +1,153 @@
"""
Django settings for comixify project.
Generated by 'django-admin startproject' using Django 2.0.7.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
# Use a separate file for the secret key
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
SECRET_KEY = f.read().strip()
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG') == 'true'
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'api',
'style_transfer',
'comic_layout',
'frontend',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'settings.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'settings.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'postgres',
'USER': 'postgres',
'HOST': 'db',
'PORT': '5432',
'PASSWORD': 'postgres',
'CONN_MAX_AGE': 60,
}
}
# Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
CACHES = {
'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake',
}
}
# Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/
STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
MAX_FILE_SIZE = 50000000
NUMBERS_OF_FRAMES_TO_SHOW = 10
TMP_DIR = 'tmp/'
GPU = True
FEATURE_BATCH_SIZE = 32
DEFAULT_FRAMES_SAMPLING_MODE = 0
DEFAULT_RL_MODE = 0
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
DEBUG = True
"""
Django settings for comixify project.
Generated by 'django-admin startproject' using Django 2.0.7.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
# Use a separate file for the secret key
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
SECRET_KEY = f.read().strip()
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG') == 'true'
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'api',
'style_transfer',
'comic_layout',
'frontend',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'settings.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'settings.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'postgres',
'USER': 'postgres',
'HOST': 'db',
'PORT': '5432',
'PASSWORD': 'postgres',
'CONN_MAX_AGE': 60,
}
}
# Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
CACHES = {
'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake',
}
}
# Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/
STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
MAX_FILE_SIZE = 50000000
NUMBERS_OF_FRAMES_TO_SHOW = 10
TMP_DIR = 'tmp/'
GPU = True
FEATURE_BATCH_SIZE = 32
DEFAULT_FRAMES_SAMPLING_MODE = 0
DEFAULT_RL_MODE = 0
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
DEFAULT_STYLE_TRANSFER_MODE = 0
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
MAX_FRAME_SIZE_FOR_STYLE_TRANSFER = 600
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')

View file

@ -9,21 +9,59 @@ from django.core.cache import cache
from torch.autograd import Variable
from CartoonGAN.network.Transformer import Transformer
from ComixGAN.model import ComixGAN
from utils import profile
# load pretrained model
comixGAN = ComixGAN()
class StyleTransfer():
@classmethod
@profile
def get_stylized_frames(cls, frames, method="cartoon_gan", gpu=settings.GPU, **kwargs):
if method == "cartoon_gan":
return cls._cartoon_gan_stylize(frames, gpu=gpu, **kwargs)
def get_stylized_frames(cls, frames, style_transfer_mode=0, gpu=settings.GPU):
if style_transfer_mode == 0:
return cls._comix_gan_stylize(frames=frames)
elif style_transfer_mode == 1:
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hayao')
elif style_transfer_mode == 2:
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hosoda')
@staticmethod
def _cartoon_gan_stylize(frames, gpu=True, **kwargs):
style = kwargs.get("style", "Hayao")
resize = kwargs.get("resize", 450)
def _resize_images(frames, size=384):
resized_images = []
for img in frames:
# resize image, keep aspect ratio
h, w, _ = img.shape
ratio = h / w
if ratio > 1:
h = size
w = int(h * 1.0 / ratio)
else:
w = size
h = int(w * ratio)
resized_img = cv2.resize(img, (w, h))
resized_images.append(resized_img)
return resized_images
@classmethod
def _comix_gan_stylize(cls, frames):
if max(frames[0].shape) > settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER:
frames = cls._resize_images(frames, size=settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER)
with comixGAN.graph.as_default():
with comixGAN.session.as_default():
batch_size = 4
stylized_imgs = []
for i in range(0, len(frames), batch_size):
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 1.25))
return list(np.concatenate(stylized_imgs, axis=0))
@classmethod
def _cartoon_gan_stylize(cls, frames, gpu=True, style='Hayao'):
model_cache_key = 'model_cache'
model = cache.get(model_cache_key) # get model from cache
@ -35,19 +73,10 @@ class StyleTransfer():
model.cuda() if gpu else model.float()
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
frames = cls._resize_images(frames, size=450)
stylized_imgs = []
for img in frames:
# resize image, keep aspect ratio
h, w, _ = img.shape
ratio = h * 1.0 / w
if ratio > 1:
h = resize
w = int(h * 1.0 / ratio)
else:
w = resize
h = int(w * ratio)
input_image = cv2.resize(img, (w, h))
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
input_image = transforms.ToTensor()(img).unsqueeze(0)
# preprocess, (-1, 1)
input_image = -1 + 2 * input_image
@ -64,6 +93,6 @@ class StyleTransfer():
output_image = np.rollaxis(output_image, 0, 3)
# append image to result images
stylized_imgs.append(output_image)
stylized_imgs.append(255 * output_image)
return stylized_imgs