Add comix gan (#16)

* Add ComixGAN #1

* Add minor fixes

* Reduce batch size

* Fix concatenate bug

* Fix GPU memory not released problem #2

* Build client

* Fix style_transfer_mode bug

* Improve timings

* Add minor fix with GPU name

* Fix tf session problem

* Compile frontend

* Change parameters

* Fix occasional yt Error
This commit is contained in:
Maciej Pęśko 2018-11-10 22:45:55 +01:00 committed by GitHub
parent 75810f7c84
commit 1e5252b8f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 352 additions and 212 deletions

24
ComixGAN/model.py Normal file
View file

@ -0,0 +1,24 @@
import errno
import os
import tensorflow as tf
from django.conf import settings
from keras.models import load_model
from keras_contrib.layers import InstanceNormalization
class ComixGAN:
    """Holder for the pretrained ComixGAN Keras model.

    The model is loaded into a dedicated TensorFlow graph and session, both
    of which are kept on the instance (``graph``, ``session``, ``model``) so
    callers can re-enter them before running predictions.
    """

    def __init__(self):
        """Load the model stored at ``settings.COMIX_GAN_MODEL_PATH``.

        Raises:
            FileNotFoundError: if the configured model path does not exist.
        """
        model_path = settings.COMIX_GAN_MODEL_PATH
        if not os.path.exists(model_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), model_path)

        # Session options: cap the per-process GPU memory fraction at 0.6
        # and let the allocation grow on demand.
        session_config = tf.ConfigProto()
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.6
        session_config.gpu_options.allow_growth = True

        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph, config=session_config)

        # Load the model with this instance's graph and session as defaults,
        # pinned to the first GPU device.
        with self.graph.as_default(), self.session.as_default(), tf.device('/device:GPU:0'):
            self.model = load_model(
                model_path,
                # The saved model references this keras_contrib layer by name.
                custom_objects={'InstanceNormalization': InstanceNormalization},
            )

Binary file not shown.

Binary file not shown.

View file

@ -23,7 +23,7 @@ class Video(models.Model):
yt_pafy = pafy.new(yt_url) yt_pafy = pafy.new(yt_url)
# Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px # Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px
for stream in yt_pafy.videostreams: for stream in reversed(yt_pafy.videostreams):
if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480: if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480:
tmp_name = uuid.uuid4().hex + ".mp4" tmp_name = uuid.uuid4().hex + ".mp4"
relative_path = jj('raw_videos', tmp_name) relative_path = jj('raw_videos', tmp_name)
@ -34,22 +34,24 @@ class Video(models.Model):
else: else:
raise TooLargeFile() raise TooLargeFile()
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0): def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0, style_transfer_mode=0):
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes( (keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
video=self, video=self,
frames_mode=frames_mode, frames_mode=frames_mode,
rl_mode=rl_mode, rl_mode=rl_mode,
image_assessment_mode=image_assessment_mode image_assessment_mode=image_assessment_mode
) )
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes) stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes,
style_transfer_mode=style_transfer_mode)
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes) comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
timings = { timings = {
'keyframes_extraction_time': keyframes_extraction_time, 'keyframes_extraction_time': keyframes_extraction_time,
'stylization_time': stylization_time, 'stylization_time': stylization_time,
'layout_generation_time': layout_generation_time, 'layout_generation_time': layout_generation_time,
**keyframes_timings 'keyframes_extraction_time_details': keyframes_timings
} }
return comic_image, timings return comic_image, timings
@ -61,7 +63,7 @@ class Comic(models.Model):
@profile @profile
def create_from_nparray(cls, nparray_file, video): def create_from_nparray(cls, nparray_file, video):
if nparray_file.max() <= 1: if nparray_file.max() <= 1:
nparray_file = (nparray_file * 255).astype(int) nparray_file = (nparray_file).astype(int)
tmp_name = uuid.uuid4().hex + ".png" tmp_name = uuid.uuid4().hex + ".png"
cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file) cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file)
with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file: with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file:

View file

@ -8,7 +8,10 @@ class VideoSerializer(serializers.Serializer):
file = serializers.FileField() file = serializers.FileField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE) frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE) rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE) image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
def validate(self, attrs): def validate(self, attrs):
file = attrs.get("file") file = attrs.get("file")
@ -23,4 +26,7 @@ class YouTubeDownloadSerializer(serializers.Serializer):
url = serializers.URLField() url = serializers.URLField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE) frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE) rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE) image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
default=settings.DEFAULT_STYLE_TRANSFER_MODE)

View file

@ -22,7 +22,8 @@ class Comixify(APIView):
comic_image, timings = video.create_comic( comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"], frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"], rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"] image_assessment_mode=serializer.validated_data["image_assessment_mode"],
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
) )
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video) comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time timings['from_nparray_time'] = from_nparray_time
@ -54,7 +55,8 @@ class ComixifyFromYoutube(APIView):
comic_image, timings = video.create_comic( comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"], frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"], rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"] image_assessment_mode=serializer.validated_data["image_assessment_mode"],
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
) )
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video) comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time timings['from_nparray_time'] = from_nparray_time

View file

@ -28,6 +28,6 @@ class LayoutGenerator():
def _pad_images(frames): def _pad_images(frames):
padded_result_imgs = [] padded_result_imgs = []
for img in frames: for img in frames:
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(1, 1, 1)) padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(255, 255, 255))
padded_result_imgs.append(padded_img) padded_result_imgs.append(padded_img)
return padded_result_imgs return padded_result_imgs

View file

@ -10,10 +10,7 @@ RUN apt-get update && apt-get install -y apt-utils software-properties-common &&
libsnappy-dev protobuf-compiler \ libsnappy-dev protobuf-compiler \
python-numpy python-setuptools python-scipy \ python-numpy python-setuptools python-scipy \
libavformat-dev libswscale-dev unzip && \ libavformat-dev libswscale-dev unzip && \
python3.6 -m pip install --upgrade pip && \ python3.6 -m pip install --upgrade pip
python3.6 -m pip install jupyter ipywidgets jupyterlab && \
python3.6 -m pip install h5py keras && \
python3.6 -m pip install scikit-image opencv-contrib-python pyyaml
RUN mkdir /comixify RUN mkdir /comixify
COPY ./Makefile.config /comixify/Makefile.config COPY ./Makefile.config /comixify/Makefile.config
@ -45,7 +42,8 @@ RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig &&
WORKDIR /comixify WORKDIR /comixify
COPY . /comixify COPY . /comixify
RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \ RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \
python3.6 -m pip install -r requirements.txt python3.6 -m pip install -r requirements.txt && \
python3.6 -m pip install git+https://www.github.com/keras-team/keras-contrib.git
# Port to expose # Port to expose

View file

@ -32,7 +32,8 @@ class App extends React.Component {
result_comics: null, result_comics: null,
framesMode: "0", framesMode: "0",
rlMode: "0", rlMode: "0",
imageAssessment: "0" imageAssessment: "0",
styleTransferMode: "0",
}; };
this.onVideoDrop = this.onVideoDrop.bind(this); this.onVideoDrop = this.onVideoDrop.bind(this);
this.onModelChange = this.onModelChange.bind(this); this.onModelChange = this.onModelChange.bind(this);
@ -40,6 +41,7 @@ class App extends React.Component {
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this); this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
this.onSamplingChange = this.onSamplingChange.bind(this); this.onSamplingChange = this.onSamplingChange.bind(this);
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this); this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
this.styleTransferChange = this.styleTransferChange.bind(this);
} }
static onVideoUploadProgress(progressEvent) { static onVideoUploadProgress(progressEvent) {
let percentCompleted = Math.round( let percentCompleted = Math.round(
@ -52,6 +54,12 @@ class App extends React.Component {
this.setState({ this.setState({
rlMode: value rlMode: value
}) })
}
styleTransferChange(e) {
let value = e.currentTarget.value;
this.setState({
styleTransferMode: value
})
} }
onSamplingChange(e) { onSamplingChange(e) {
let value = e.currentTarget.value; let value = e.currentTarget.value;
@ -78,12 +86,13 @@ class App extends React.Component {
} }
} }
processVideo(video) { processVideo(video) {
let { framesMode, rlMode, imageAssessment } = this.state; let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
let data = new FormData(); let data = new FormData();
data.append("file", video); data.append("file", video);
data.set('frames_mode', parseInt(framesMode)); data.set('frames_mode', parseInt(framesMode));
data.set('rl_mode', parseInt(rlMode)); data.set('rl_mode', parseInt(rlMode));
data.set("image_assessment_mode", parseInt(imageAssessment)); data.set("image_assessment_mode", parseInt(imageAssessment));
data.set('style_transfer_mode', parseInt(styleTransferMode));
post(COMIXIFY_API, data, { post(COMIXIFY_API, data, {
headers: { "content-type": "multipart/form-data" }, headers: { "content-type": "multipart/form-data" },
onUploadProgress: App.onVideoUploadProgress onUploadProgress: App.onVideoUploadProgress
@ -111,12 +120,13 @@ class App extends React.Component {
this.processVideo(files[0]); this.processVideo(files[0]);
} }
submitYouTube(link) { submitYouTube(link) {
let { framesMode, rlMode, imageAssessment } = this.state; let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
post(FROM_YOUTUBE_API, { post(FROM_YOUTUBE_API, {
url: link, url: link,
frames_mode: parseInt(framesMode), frames_mode: parseInt(framesMode),
rl_mode: parseInt(rlMode), rl_mode: parseInt(rlMode),
image_assessment_mode: parseInt(imageAssessment) image_assessment_mode: parseInt(imageAssessment),
style_transfer_mode: parseInt(styleTransferMode)
}) })
.then(this.handleResponse) .then(this.handleResponse)
.catch(err => { .catch(err => {
@ -143,7 +153,7 @@ class App extends React.Component {
} }
render() { render() {
let { let {
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment, styleTransferMode
} = this.state; } = this.state;
let showUsage = [ let showUsage = [
App.appStates.INITIAL, App.appStates.INITIAL,
@ -231,6 +241,36 @@ class App extends React.Component {
onChange={this.onImageAssessmentChange} onChange={this.onImageAssessmentChange}
/> />
<label htmlFor="image-assessment-1">Popularity</label> <label htmlFor="image-assessment-1">Popularity</label>
</div>
<div>
<span>Style Transfer model:</span>
<input
type="radio"
name="style-model"
id="style-model-0"
value="0"
checked={styleTransferMode === "0"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-0">ComixGAN</label>
<input
type="radio"
name="style-model"
id="style-model-1"
value="1"
checked={styleTransferMode === "1"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-1">CartoonGAN-Hayao</label>
<input
type="radio"
name="style-model"
id="style-model-2"
value="2"
checked={styleTransferMode === "2"}
onChange={this.styleTransferChange}
/>
<label htmlFor="style-model-2">CartoonGAN-Hosoda</label>
</div> </div>
</div> </div>
)} )}

File diff suppressed because one or more lines are too long

View file

@ -60,7 +60,7 @@
<i>* Max size of video is 50 MB</i> <i>* Max size of video is 50 MB</i>
<h4>Respose:</h4> <h4>Response:</h4>
<code> <code>
{ {
<br> <br>

View file

@ -18,7 +18,7 @@ import logging
from utils import jj, profile from utils import jj, profile
from keyframes_rl.models import DSN from keyframes_rl.models import DSN
from popularity.models import PopularityPredictor from popularity.models import PopularityPredictor
from neural_image_assessment.models import NeuralImageAssessment from neural_image_assessment.model import NeuralImageAssessment
from keyframes.kts import cpd_auto from keyframes.kts import cpd_auto
from keyframes.utils import batch from keyframes.utils import batch

View file

@ -1,27 +1,27 @@
import os
import errno import errno
import os
import numpy as np import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array
from keras.applications.nasnet import preprocess_input
import tensorflow as tf import tensorflow as tf
from PIL import Image from PIL import Image
from keras.applications.nasnet import preprocess_input
from keras.models import load_model
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5' from keras.preprocessing.image import img_to_array
from django.conf import settings
class NeuralImageAssessment: class NeuralImageAssessment:
def __init__(self): def __init__(self):
if not os.path.exists(MODEL_PATH): if not os.path.exists(settings.NIMA_MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH) raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
self.graph = tf.Graph() self.graph = tf.Graph()
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1 config.gpu_options.per_process_gpu_memory_fraction = 0.6
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
self.session = tf.Session(graph=self.graph, config=config) self.session = tf.Session(graph=self.graph, config=config)
with self.graph.as_default(): with self.graph.as_default():
self.model = load_model(MODEL_PATH) with self.session.as_default():
self.model = load_model(settings.NIMA_MODEL_PATH)
@staticmethod @staticmethod
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'): def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
@ -46,13 +46,14 @@ class NeuralImageAssessment:
def get_assessment_score(self, img_array): def get_assessment_score(self, img_array):
with self.graph.as_default(): with self.graph.as_default():
target_size = (224, 224) with self.session.as_default():
img = NeuralImageAssessment.resize_image(img_array, target_size) target_size = (224, 224)
x = img_to_array(img) img = NeuralImageAssessment.resize_image(img_array, target_size)
x = np.expand_dims(x, axis=0) x = img_to_array(img)
x = preprocess_input(x) x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
scores = self.model.predict(x, batch_size=1, verbose=0)[0] scores = self.model.predict(x, batch_size=1, verbose=0)[0]
mean = NeuralImageAssessment.mean_score(scores) mean = NeuralImageAssessment.mean_score(scores)
return mean return mean

View file

@ -72,6 +72,10 @@ echo 'export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
source ~/.bashrc source ~/.bashrc
# CHECK CUDNN VERSION
cat /usr/include/x86_64-linux-gnu/cudnn_v*.h | grep CUDNN_MAJOR -A 2
# INSTALL PACKAGES # INSTALL PACKAGES
export LC_ALL="en_US.UTF-8" export LC_ALL="en_US.UTF-8"
export LC_CTYPE="en_US.UTF-8" export LC_CTYPE="en_US.UTF-8"

View file

@ -1,16 +1,45 @@
astor==0.7.1
cloudpickle==0.6.1
cycler==0.10.0
dask==0.20.0
decorator==4.3.0
Django==2.0.7 Django==2.0.7
django-rest-framework==0.1.0 django-rest-framework==0.1.0
djangorestframework==3.8.2 djangorestframework==3.8.2
gast==0.2.0
grpcio==1.16.0
gunicorn==19.9.0 gunicorn==19.9.0
h5py==2.8.0
Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
kiwisolver==1.0.1
Markdown==3.0.1
matplotlib==3.0.1
networkx==2.2
numpy==1.14.5 numpy==1.14.5
opencv-contrib-python==3.4.3.18
opencv-python==3.4.2.17 opencv-python==3.4.2.17
pafy==0.5.4 pafy==0.5.4
Pillow==5.2.0 Pillow==5.2.0
protobuf==3.6.1
psycopg2==2.7.5 psycopg2==2.7.5
pyparsing==2.3.0
python-dateutil==2.7.5
pytz==2018.5 pytz==2018.5
PyWavelets==1.0.1
PyYAML==3.13
scikit-image==0.14.1
scikit-learn==0.20.0
scipy==1.1.0
six==1.11.0 six==1.11.0
sklearn==0.0
tensorboard==1.10.0
tensorflow-gpu==1.10.1
termcolor==1.1.0
toolz==0.9.0
torch==0.4.1 torch==0.4.1
torchvision==0.2.1 torchvision==0.2.1
scikit-learn==0.19.2 Werkzeug==0.14.1
youtube-dl==2018.9.18 youtube-dl==2018.11.7
tensorflow-gpu==1.10.1 tensorflow-gpu==1.10.1

View file

@ -1,148 +1,153 @@
""" """
Django settings for comixify project. Django settings for comixify project.
Generated by 'django-admin startproject' using Django 2.0.7. Generated by 'django-admin startproject' using Django 2.0.7.
For more information on this file, see For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/ https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/ https://docs.djangoproject.com/en/2.0/ref/settings/
""" """
import os import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...) # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production # Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ # See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret! # SECURITY WARNING: keep the secret key used in production secret!
# Use a separate file for the secret key # Use a separate file for the secret key
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f: with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
SECRET_KEY = f.read().strip() SECRET_KEY = f.read().strip()
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG') == 'true' DEBUG = os.environ.get('DEBUG') == 'true'
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1'] ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
# Application definition # Application definition
INSTALLED_APPS = [ INSTALLED_APPS = [
'django.contrib.admin', 'django.contrib.admin',
'django.contrib.auth', 'django.contrib.auth',
'django.contrib.contenttypes', 'django.contrib.contenttypes',
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework', 'rest_framework',
'api', 'api',
'style_transfer', 'style_transfer',
'comic_layout', 'comic_layout',
'frontend', 'frontend',
] ]
MIDDLEWARE = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
] ]
ROOT_URLCONF = 'settings.urls' ROOT_URLCONF = 'settings.urls'
TEMPLATES = [ TEMPLATES = [
{ {
'BACKEND': 'django.template.backends.django.DjangoTemplates', 'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [], 'DIRS': [],
'APP_DIRS': True, 'APP_DIRS': True,
'OPTIONS': { 'OPTIONS': {
'context_processors': [ 'context_processors': [
'django.template.context_processors.debug', 'django.template.context_processors.debug',
'django.template.context_processors.request', 'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth', 'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages', 'django.contrib.messages.context_processors.messages',
], ],
}, },
}, },
] ]
WSGI_APPLICATION = 'settings.wsgi.application' WSGI_APPLICATION = 'settings.wsgi.application'
# Database # Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases # https://docs.djangoproject.com/en/2.0/ref/settings/#databases
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2', 'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'postgres', 'NAME': 'postgres',
'USER': 'postgres', 'USER': 'postgres',
'HOST': 'db', 'HOST': 'db',
'PORT': '5432', 'PORT': '5432',
'PASSWORD': 'postgres', 'PASSWORD': 'postgres',
'CONN_MAX_AGE': 60, 'CONN_MAX_AGE': 60,
} }
} }
# Password validation # Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators # https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ {
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
}, },
] ]
CACHES = { CACHES = {
'default': { 'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake', 'LOCATION': 'unique-snowflake',
} }
} }
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/ # https://docs.djangoproject.com/en/2.0/topics/i18n/
LANGUAGE_CODE = 'en-us' LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC' TIME_ZONE = 'UTC'
USE_I18N = True USE_I18N = True
USE_L10N = True USE_L10N = True
USE_TZ = True USE_TZ = True
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/ # https://docs.djangoproject.com/en/2.0/howto/static-files/
STATIC_URL = '/static/' STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static') STATIC_ROOT = os.path.join(BASE_DIR, 'static')
MEDIA_URL = '/media/' MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media') MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi'] PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
MAX_FILE_SIZE = 50000000 MAX_FILE_SIZE = 50000000
NUMBERS_OF_FRAMES_TO_SHOW = 10 NUMBERS_OF_FRAMES_TO_SHOW = 10
TMP_DIR = 'tmp/' TMP_DIR = 'tmp/'
GPU = True GPU = True
FEATURE_BATCH_SIZE = 32 FEATURE_BATCH_SIZE = 32
DEFAULT_FRAMES_SAMPLING_MODE = 0 DEFAULT_FRAMES_SAMPLING_MODE = 0
DEFAULT_RL_MODE = 0 DEFAULT_RL_MODE = 0
DEFAULT_IMAGE_ASSESSMENT_MODE = 0 DEFAULT_IMAGE_ASSESSMENT_MODE = 0
DEBUG = True
DEFAULT_STYLE_TRANSFER_MODE = 0
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
MAX_FRAME_SIZE_FOR_STYLE_TRANSFER = 600
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')

View file

@ -9,21 +9,59 @@ from django.core.cache import cache
from torch.autograd import Variable from torch.autograd import Variable
from CartoonGAN.network.Transformer import Transformer from CartoonGAN.network.Transformer import Transformer
from ComixGAN.model import ComixGAN
from utils import profile from utils import profile
# load pretrained model
comixGAN = ComixGAN()
class StyleTransfer(): class StyleTransfer():
@classmethod @classmethod
@profile @profile
def get_stylized_frames(cls, frames, method="cartoon_gan", gpu=settings.GPU, **kwargs): def get_stylized_frames(cls, frames, style_transfer_mode=0, gpu=settings.GPU):
if method == "cartoon_gan": if style_transfer_mode == 0:
return cls._cartoon_gan_stylize(frames, gpu=gpu, **kwargs) return cls._comix_gan_stylize(frames=frames)
elif style_transfer_mode == 1:
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hayao')
elif style_transfer_mode == 2:
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hosoda')
@staticmethod @staticmethod
def _cartoon_gan_stylize(frames, gpu=True, **kwargs): def _resize_images(frames, size=384):
style = kwargs.get("style", "Hayao") resized_images = []
resize = kwargs.get("resize", 450) for img in frames:
# resize image, keep aspect ratio
h, w, _ = img.shape
ratio = h / w
if ratio > 1:
h = size
w = int(h * 1.0 / ratio)
else:
w = size
h = int(w * ratio)
resized_img = cv2.resize(img, (w, h))
resized_images.append(resized_img)
return resized_images
@classmethod
def _comix_gan_stylize(cls, frames):
    """Run the frames through the module-level ComixGAN model in batches.

    Args:
        frames: sequence of image arrays; presumably all the same shape,
            since each batch is combined with ``np.stack`` -- TODO confirm.

    Returns:
        list of stylized image arrays, one per input frame.
    """
    size_limit = settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER
    # Only the first frame's shape is inspected to decide on downscaling.
    if max(frames[0].shape) > size_limit:
        frames = cls._resize_images(frames, size=size_limit)

    batch_size = 4
    stylized_batches = []
    with comixGAN.graph.as_default(), comixGAN.session.as_default():
        for start in range(0, len(frames), batch_size):
            chunk = np.stack(frames[start:start + batch_size])
            # Rescale input: x / 255 * 2 - 1.
            model_input = ((chunk / 255) * 2) - 1
            prediction = comixGAN.model.predict(model_input)
            # Rescale output: 255 * (y + 1) / 1.25.
            # NOTE(review): divisor is 1.25 rather than 2 -- looks like
            # deliberate tuning; confirm before changing.
            stylized_batches.append(255 * ((prediction + 1) / 1.25))
    return list(np.concatenate(stylized_batches, axis=0))
@classmethod
def _cartoon_gan_stylize(cls, frames, gpu=True, style='Hayao'):
model_cache_key = 'model_cache' model_cache_key = 'model_cache'
model = cache.get(model_cache_key) # get model from cache model = cache.get(model_cache_key) # get model from cache
@ -35,19 +73,10 @@ class StyleTransfer():
model.cuda() if gpu else model.float() model.cuda() if gpu else model.float()
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
frames = cls._resize_images(frames, size=450)
stylized_imgs = [] stylized_imgs = []
for img in frames: for img in frames:
# resize image, keep aspect ratio input_image = transforms.ToTensor()(img).unsqueeze(0)
h, w, _ = img.shape
ratio = h * 1.0 / w
if ratio > 1:
h = resize
w = int(h * 1.0 / ratio)
else:
w = resize
h = int(w * ratio)
input_image = cv2.resize(img, (w, h))
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
# preprocess, (-1, 1) # preprocess, (-1, 1)
input_image = -1 + 2 * input_image input_image = -1 + 2 * input_image
@ -64,6 +93,6 @@ class StyleTransfer():
output_image = np.rollaxis(output_image, 0, 3) output_image = np.rollaxis(output_image, 0, 3)
# append image to result images # append image to result images
stylized_imgs.append(output_image) stylized_imgs.append(255 * output_image)
return stylized_imgs return stylized_imgs