forked from prehistoric-systems/comixify
Add comix gan (#16)
* Add ComixGAN #1 * Add minor fixes * Reduce batch size * Fix concatenate bug * Fix GPU memory not released problem #2 * Build client * Fix style_transfer_mode bug * Improve timings * Add minor fix with GPU name * Fix tf session problem * Compile frontend * Change parameters * Fix occasional yt Error
This commit is contained in:
parent
75810f7c84
commit
1e5252b8f0
17 changed files with 352 additions and 212 deletions
24
ComixGAN/model.py
Normal file
24
ComixGAN/model.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import errno
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
from django.conf import settings
|
||||
from keras.models import load_model
|
||||
from keras_contrib.layers import InstanceNormalization
|
||||
|
||||
|
||||
class ComixGAN:
|
||||
def __init__(self):
|
||||
if not os.path.exists(settings.COMIX_GAN_MODEL_PATH):
|
||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.COMIX_GAN_MODEL_PATH)
|
||||
self.graph = tf.Graph()
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.6
|
||||
config.gpu_options.allow_growth = True
|
||||
self.session = tf.Session(graph=self.graph, config=config)
|
||||
with self.graph.as_default():
|
||||
with self.session.as_default():
|
||||
with tf.device('/device:GPU:0'):
|
||||
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
|
||||
custom_objects={'InstanceNormalization': InstanceNormalization})
|
||||
|
||||
BIN
ComixGAN/pretrained_models/generator_model.h5
Normal file
BIN
ComixGAN/pretrained_models/generator_model.h5
Normal file
Binary file not shown.
BIN
ComixGAN/pretrained_models/generator_model2.h5
Normal file
BIN
ComixGAN/pretrained_models/generator_model2.h5
Normal file
Binary file not shown.
|
|
@ -23,7 +23,7 @@ class Video(models.Model):
|
|||
yt_pafy = pafy.new(yt_url)
|
||||
|
||||
# Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px
|
||||
for stream in yt_pafy.videostreams:
|
||||
for stream in reversed(yt_pafy.videostreams):
|
||||
if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480:
|
||||
tmp_name = uuid.uuid4().hex + ".mp4"
|
||||
relative_path = jj('raw_videos', tmp_name)
|
||||
|
|
@ -34,22 +34,24 @@ class Video(models.Model):
|
|||
else:
|
||||
raise TooLargeFile()
|
||||
|
||||
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0):
|
||||
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0, style_transfer_mode=0):
|
||||
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
|
||||
video=self,
|
||||
frames_mode=frames_mode,
|
||||
rl_mode=rl_mode,
|
||||
image_assessment_mode=image_assessment_mode
|
||||
)
|
||||
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes)
|
||||
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes,
|
||||
style_transfer_mode=style_transfer_mode)
|
||||
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
|
||||
|
||||
timings = {
|
||||
'keyframes_extraction_time': keyframes_extraction_time,
|
||||
'stylization_time': stylization_time,
|
||||
'layout_generation_time': layout_generation_time,
|
||||
**keyframes_timings
|
||||
'keyframes_extraction_time_details': keyframes_timings
|
||||
}
|
||||
|
||||
return comic_image, timings
|
||||
|
||||
|
||||
|
|
@ -61,7 +63,7 @@ class Comic(models.Model):
|
|||
@profile
|
||||
def create_from_nparray(cls, nparray_file, video):
|
||||
if nparray_file.max() <= 1:
|
||||
nparray_file = (nparray_file * 255).astype(int)
|
||||
nparray_file = (nparray_file).astype(int)
|
||||
tmp_name = uuid.uuid4().hex + ".png"
|
||||
cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file)
|
||||
with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ class VideoSerializer(serializers.Serializer):
|
|||
file = serializers.FileField()
|
||||
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
||||
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
|
||||
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
|
||||
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
|
||||
|
||||
def validate(self, attrs):
|
||||
file = attrs.get("file")
|
||||
|
|
@ -23,4 +26,7 @@ class YouTubeDownloadSerializer(serializers.Serializer):
|
|||
url = serializers.URLField()
|
||||
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
||||
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
|
||||
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
|
||||
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ class Comixify(APIView):
|
|||
comic_image, timings = video.create_comic(
|
||||
frames_mode=serializer.validated_data["frames_mode"],
|
||||
rl_mode=serializer.validated_data["rl_mode"],
|
||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
|
||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
|
||||
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
|
||||
)
|
||||
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
||||
timings['from_nparray_time'] = from_nparray_time
|
||||
|
|
@ -54,7 +55,8 @@ class ComixifyFromYoutube(APIView):
|
|||
comic_image, timings = video.create_comic(
|
||||
frames_mode=serializer.validated_data["frames_mode"],
|
||||
rl_mode=serializer.validated_data["rl_mode"],
|
||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
|
||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
|
||||
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
|
||||
)
|
||||
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
||||
timings['from_nparray_time'] = from_nparray_time
|
||||
|
|
|
|||
|
|
@ -28,6 +28,6 @@ class LayoutGenerator():
|
|||
def _pad_images(frames):
|
||||
padded_result_imgs = []
|
||||
for img in frames:
|
||||
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(1, 1, 1))
|
||||
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(255, 255, 255))
|
||||
padded_result_imgs.append(padded_img)
|
||||
return padded_result_imgs
|
||||
|
|
|
|||
|
|
@ -10,10 +10,7 @@ RUN apt-get update && apt-get install -y apt-utils software-properties-common &&
|
|||
libsnappy-dev protobuf-compiler \
|
||||
python-numpy python-setuptools python-scipy \
|
||||
libavformat-dev libswscale-dev unzip && \
|
||||
python3.6 -m pip install --upgrade pip && \
|
||||
python3.6 -m pip install jupyter ipywidgets jupyterlab && \
|
||||
python3.6 -m pip install h5py keras && \
|
||||
python3.6 -m pip install scikit-image opencv-contrib-python pyyaml
|
||||
python3.6 -m pip install --upgrade pip
|
||||
|
||||
RUN mkdir /comixify
|
||||
COPY ./Makefile.config /comixify/Makefile.config
|
||||
|
|
@ -45,7 +42,8 @@ RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig &&
|
|||
WORKDIR /comixify
|
||||
COPY . /comixify
|
||||
RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \
|
||||
python3.6 -m pip install -r requirements.txt
|
||||
python3.6 -m pip install -r requirements.txt && \
|
||||
python3.6 -m pip install git+https://www.github.com/keras-team/keras-contrib.git
|
||||
|
||||
|
||||
# Port to expose
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ class App extends React.Component {
|
|||
result_comics: null,
|
||||
framesMode: "0",
|
||||
rlMode: "0",
|
||||
imageAssessment: "0"
|
||||
imageAssessment: "0",
|
||||
styleTransferMode: "0",
|
||||
};
|
||||
this.onVideoDrop = this.onVideoDrop.bind(this);
|
||||
this.onModelChange = this.onModelChange.bind(this);
|
||||
|
|
@ -40,6 +41,7 @@ class App extends React.Component {
|
|||
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
|
||||
this.onSamplingChange = this.onSamplingChange.bind(this);
|
||||
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
|
||||
this.styleTransferChange = this.styleTransferChange.bind(this);
|
||||
}
|
||||
static onVideoUploadProgress(progressEvent) {
|
||||
let percentCompleted = Math.round(
|
||||
|
|
@ -52,6 +54,12 @@ class App extends React.Component {
|
|||
this.setState({
|
||||
rlMode: value
|
||||
})
|
||||
}
|
||||
styleTransferChange(e) {
|
||||
let value = e.currentTarget.value;
|
||||
this.setState({
|
||||
styleTransferMode: value
|
||||
})
|
||||
}
|
||||
onSamplingChange(e) {
|
||||
let value = e.currentTarget.value;
|
||||
|
|
@ -78,12 +86,13 @@ class App extends React.Component {
|
|||
}
|
||||
}
|
||||
processVideo(video) {
|
||||
let { framesMode, rlMode, imageAssessment } = this.state;
|
||||
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
|
||||
let data = new FormData();
|
||||
data.append("file", video);
|
||||
data.set('frames_mode', parseInt(framesMode));
|
||||
data.set('rl_mode', parseInt(rlMode));
|
||||
data.set("image_assessment_mode", parseInt(imageAssessment));
|
||||
data.set('style_transfer_mode', parseInt(styleTransferMode));
|
||||
post(COMIXIFY_API, data, {
|
||||
headers: { "content-type": "multipart/form-data" },
|
||||
onUploadProgress: App.onVideoUploadProgress
|
||||
|
|
@ -111,12 +120,13 @@ class App extends React.Component {
|
|||
this.processVideo(files[0]);
|
||||
}
|
||||
submitYouTube(link) {
|
||||
let { framesMode, rlMode, imageAssessment } = this.state;
|
||||
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
|
||||
post(FROM_YOUTUBE_API, {
|
||||
url: link,
|
||||
frames_mode: parseInt(framesMode),
|
||||
rl_mode: parseInt(rlMode),
|
||||
image_assessment_mode: parseInt(imageAssessment)
|
||||
image_assessment_mode: parseInt(imageAssessment),
|
||||
style_transfer_mode: parseInt(styleTransferMode)
|
||||
})
|
||||
.then(this.handleResponse)
|
||||
.catch(err => {
|
||||
|
|
@ -143,7 +153,7 @@ class App extends React.Component {
|
|||
}
|
||||
render() {
|
||||
let {
|
||||
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment
|
||||
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment, styleTransferMode
|
||||
} = this.state;
|
||||
let showUsage = [
|
||||
App.appStates.INITIAL,
|
||||
|
|
@ -231,6 +241,36 @@ class App extends React.Component {
|
|||
onChange={this.onImageAssessmentChange}
|
||||
/>
|
||||
<label htmlFor="image-assessment-1">Popularity</label>
|
||||
</div>
|
||||
<div>
|
||||
<span>Style Transfer model:</span>
|
||||
<input
|
||||
type="radio"
|
||||
name="style-model"
|
||||
id="style-model-0"
|
||||
value="0"
|
||||
checked={styleTransferMode === "0"}
|
||||
onChange={this.styleTransferChange}
|
||||
/>
|
||||
<label htmlFor="style-model-0">ComixGAN</label>
|
||||
<input
|
||||
type="radio"
|
||||
name="style-model"
|
||||
id="style-model-1"
|
||||
value="1"
|
||||
checked={styleTransferMode === "1"}
|
||||
onChange={this.styleTransferChange}
|
||||
/>
|
||||
<label htmlFor="style-model-1">CartoonGAN-Hayao</label>
|
||||
<input
|
||||
type="radio"
|
||||
name="style-model"
|
||||
id="style-model-2"
|
||||
value="2"
|
||||
checked={styleTransferMode === "2"}
|
||||
onChange={this.styleTransferChange}
|
||||
/>
|
||||
<label htmlFor="style-model-2">CartoonGAN-Hosoda</label>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -60,7 +60,7 @@
|
|||
|
||||
<i>* Max size of video is 50 MB</i>
|
||||
|
||||
<h4>Respose:</h4>
|
||||
<h4>Response:</h4>
|
||||
<code>
|
||||
{
|
||||
<br>
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import logging
|
|||
from utils import jj, profile
|
||||
from keyframes_rl.models import DSN
|
||||
from popularity.models import PopularityPredictor
|
||||
from neural_image_assessment.models import NeuralImageAssessment
|
||||
from neural_image_assessment.model import NeuralImageAssessment
|
||||
from keyframes.kts import cpd_auto
|
||||
from keyframes.utils import batch
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,27 @@
|
|||
import os
|
||||
import errno
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from keras.models import load_model
|
||||
from keras.preprocessing.image import img_to_array
|
||||
from keras.applications.nasnet import preprocess_input
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
|
||||
|
||||
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
|
||||
from keras.applications.nasnet import preprocess_input
|
||||
from keras.models import load_model
|
||||
from keras.preprocessing.image import img_to_array
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class NeuralImageAssessment:
|
||||
def __init__(self):
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)
|
||||
if not os.path.exists(settings.NIMA_MODEL_PATH):
|
||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
|
||||
self.graph = tf.Graph()
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.1
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.6
|
||||
config.gpu_options.allow_growth = True
|
||||
self.session = tf.Session(graph=self.graph, config=config)
|
||||
with self.graph.as_default():
|
||||
self.model = load_model(MODEL_PATH)
|
||||
with self.session.as_default():
|
||||
self.model = load_model(settings.NIMA_MODEL_PATH)
|
||||
|
||||
@staticmethod
|
||||
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
|
||||
|
|
@ -46,13 +46,14 @@ class NeuralImageAssessment:
|
|||
|
||||
def get_assessment_score(self, img_array):
|
||||
with self.graph.as_default():
|
||||
target_size = (224, 224)
|
||||
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
||||
x = img_to_array(img)
|
||||
x = np.expand_dims(x, axis=0)
|
||||
x = preprocess_input(x)
|
||||
with self.session.as_default():
|
||||
target_size = (224, 224)
|
||||
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
||||
x = img_to_array(img)
|
||||
x = np.expand_dims(x, axis=0)
|
||||
x = preprocess_input(x)
|
||||
|
||||
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
||||
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
||||
mean = NeuralImageAssessment.mean_score(scores)
|
||||
|
||||
return mean
|
||||
|
|
@ -72,6 +72,10 @@ echo 'export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
|
|||
source ~/.bashrc
|
||||
|
||||
|
||||
# CHECK CUDNN VERSION
|
||||
cat /usr/include/x86_64-linux-gnu/cudnn_v*.h | grep CUDNN_MAJOR -A 2
|
||||
|
||||
|
||||
# INSTALL PACKAGES
|
||||
export LC_ALL="en_US.UTF-8"
|
||||
export LC_CTYPE="en_US.UTF-8"
|
||||
|
|
|
|||
|
|
@ -1,16 +1,45 @@
|
|||
astor==0.7.1
|
||||
cloudpickle==0.6.1
|
||||
cycler==0.10.0
|
||||
dask==0.20.0
|
||||
decorator==4.3.0
|
||||
Django==2.0.7
|
||||
django-rest-framework==0.1.0
|
||||
djangorestframework==3.8.2
|
||||
gast==0.2.0
|
||||
grpcio==1.16.0
|
||||
gunicorn==19.9.0
|
||||
h5py==2.8.0
|
||||
Keras==2.2.4
|
||||
Keras-Applications==1.0.6
|
||||
Keras-Preprocessing==1.0.5
|
||||
kiwisolver==1.0.1
|
||||
Markdown==3.0.1
|
||||
matplotlib==3.0.1
|
||||
networkx==2.2
|
||||
numpy==1.14.5
|
||||
opencv-contrib-python==3.4.3.18
|
||||
opencv-python==3.4.2.17
|
||||
pafy==0.5.4
|
||||
Pillow==5.2.0
|
||||
protobuf==3.6.1
|
||||
psycopg2==2.7.5
|
||||
pyparsing==2.3.0
|
||||
python-dateutil==2.7.5
|
||||
pytz==2018.5
|
||||
PyWavelets==1.0.1
|
||||
PyYAML==3.13
|
||||
scikit-image==0.14.1
|
||||
scikit-learn==0.20.0
|
||||
scipy==1.1.0
|
||||
six==1.11.0
|
||||
sklearn==0.0
|
||||
tensorboard==1.10.0
|
||||
tensorflow-gpu==1.10.1
|
||||
termcolor==1.1.0
|
||||
toolz==0.9.0
|
||||
torch==0.4.1
|
||||
torchvision==0.2.1
|
||||
scikit-learn==0.19.2
|
||||
youtube-dl==2018.9.18
|
||||
Werkzeug==0.14.1
|
||||
youtube-dl==2018.11.7
|
||||
tensorflow-gpu==1.10.1
|
||||
|
|
@ -1,148 +1,153 @@
|
|||
"""
|
||||
Django settings for comixify project.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 2.0.7.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/2.0/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/2.0/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
# Use a separate file for the secret key
|
||||
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
|
||||
SECRET_KEY = f.read().strip()
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = os.environ.get('DEBUG') == 'true'
|
||||
|
||||
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'rest_framework',
|
||||
'api',
|
||||
'style_transfer',
|
||||
'comic_layout',
|
||||
'frontend',
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'settings.urls'
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': [],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
WSGI_APPLICATION = 'settings.wsgi.application'
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.postgresql_psycopg2',
|
||||
'NAME': 'postgres',
|
||||
'USER': 'postgres',
|
||||
'HOST': 'db',
|
||||
'PORT': '5432',
|
||||
'PASSWORD': 'postgres',
|
||||
'CONN_MAX_AGE': 60,
|
||||
}
|
||||
}
|
||||
|
||||
# Password validation
|
||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
|
||||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||
},
|
||||
]
|
||||
CACHES = {
|
||||
'default': {
|
||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||
'LOCATION': 'unique-snowflake',
|
||||
}
|
||||
}
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/2.0/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
|
||||
TIME_ZONE = 'UTC'
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/2.0/howto/static-files/
|
||||
|
||||
STATIC_URL = '/static/'
|
||||
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
||||
|
||||
MEDIA_URL = '/media/'
|
||||
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
||||
|
||||
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
|
||||
MAX_FILE_SIZE = 50000000
|
||||
NUMBERS_OF_FRAMES_TO_SHOW = 10
|
||||
TMP_DIR = 'tmp/'
|
||||
GPU = True
|
||||
|
||||
FEATURE_BATCH_SIZE = 32
|
||||
DEFAULT_FRAMES_SAMPLING_MODE = 0
|
||||
DEFAULT_RL_MODE = 0
|
||||
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
|
||||
DEBUG = True
|
||||
"""
|
||||
Django settings for comixify project.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 2.0.7.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/2.0/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/2.0/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
# Use a separate file for the secret key
|
||||
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
|
||||
SECRET_KEY = f.read().strip()
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = os.environ.get('DEBUG') == 'true'
|
||||
|
||||
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'rest_framework',
|
||||
'api',
|
||||
'style_transfer',
|
||||
'comic_layout',
|
||||
'frontend',
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'settings.urls'
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': [],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
WSGI_APPLICATION = 'settings.wsgi.application'
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.postgresql_psycopg2',
|
||||
'NAME': 'postgres',
|
||||
'USER': 'postgres',
|
||||
'HOST': 'db',
|
||||
'PORT': '5432',
|
||||
'PASSWORD': 'postgres',
|
||||
'CONN_MAX_AGE': 60,
|
||||
}
|
||||
}
|
||||
|
||||
# Password validation
|
||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
|
||||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||
},
|
||||
]
|
||||
CACHES = {
|
||||
'default': {
|
||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||
'LOCATION': 'unique-snowflake',
|
||||
}
|
||||
}
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/2.0/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
|
||||
TIME_ZONE = 'UTC'
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/2.0/howto/static-files/
|
||||
|
||||
STATIC_URL = '/static/'
|
||||
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
||||
|
||||
MEDIA_URL = '/media/'
|
||||
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
||||
|
||||
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
|
||||
MAX_FILE_SIZE = 50000000
|
||||
NUMBERS_OF_FRAMES_TO_SHOW = 10
|
||||
TMP_DIR = 'tmp/'
|
||||
GPU = True
|
||||
|
||||
FEATURE_BATCH_SIZE = 32
|
||||
DEFAULT_FRAMES_SAMPLING_MODE = 0
|
||||
DEFAULT_RL_MODE = 0
|
||||
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
|
||||
|
||||
DEFAULT_STYLE_TRANSFER_MODE = 0
|
||||
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
|
||||
MAX_FRAME_SIZE_FOR_STYLE_TRANSFER = 600
|
||||
|
||||
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')
|
||||
|
|
@ -9,21 +9,59 @@ from django.core.cache import cache
|
|||
from torch.autograd import Variable
|
||||
|
||||
from CartoonGAN.network.Transformer import Transformer
|
||||
from ComixGAN.model import ComixGAN
|
||||
from utils import profile
|
||||
|
||||
# load pretrained model
|
||||
comixGAN = ComixGAN()
|
||||
|
||||
|
||||
class StyleTransfer():
|
||||
@classmethod
|
||||
@profile
|
||||
def get_stylized_frames(cls, frames, method="cartoon_gan", gpu=settings.GPU, **kwargs):
|
||||
if method == "cartoon_gan":
|
||||
return cls._cartoon_gan_stylize(frames, gpu=gpu, **kwargs)
|
||||
def get_stylized_frames(cls, frames, style_transfer_mode=0, gpu=settings.GPU):
|
||||
if style_transfer_mode == 0:
|
||||
return cls._comix_gan_stylize(frames=frames)
|
||||
elif style_transfer_mode == 1:
|
||||
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hayao')
|
||||
elif style_transfer_mode == 2:
|
||||
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hosoda')
|
||||
|
||||
@staticmethod
|
||||
def _cartoon_gan_stylize(frames, gpu=True, **kwargs):
|
||||
style = kwargs.get("style", "Hayao")
|
||||
resize = kwargs.get("resize", 450)
|
||||
def _resize_images(frames, size=384):
|
||||
resized_images = []
|
||||
for img in frames:
|
||||
# resize image, keep aspect ratio
|
||||
h, w, _ = img.shape
|
||||
ratio = h / w
|
||||
if ratio > 1:
|
||||
h = size
|
||||
w = int(h * 1.0 / ratio)
|
||||
else:
|
||||
w = size
|
||||
h = int(w * ratio)
|
||||
resized_img = cv2.resize(img, (w, h))
|
||||
resized_images.append(resized_img)
|
||||
return resized_images
|
||||
|
||||
@classmethod
|
||||
def _comix_gan_stylize(cls, frames):
|
||||
if max(frames[0].shape) > settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER:
|
||||
frames = cls._resize_images(frames, size=settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER)
|
||||
|
||||
with comixGAN.graph.as_default():
|
||||
with comixGAN.session.as_default():
|
||||
batch_size = 4
|
||||
stylized_imgs = []
|
||||
for i in range(0, len(frames), batch_size):
|
||||
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
|
||||
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
|
||||
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 1.25))
|
||||
|
||||
return list(np.concatenate(stylized_imgs, axis=0))
|
||||
|
||||
@classmethod
|
||||
def _cartoon_gan_stylize(cls, frames, gpu=True, style='Hayao'):
|
||||
model_cache_key = 'model_cache'
|
||||
model = cache.get(model_cache_key) # get model from cache
|
||||
|
||||
|
|
@ -35,19 +73,10 @@ class StyleTransfer():
|
|||
model.cuda() if gpu else model.float()
|
||||
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
|
||||
|
||||
frames = cls._resize_images(frames, size=450)
|
||||
stylized_imgs = []
|
||||
for img in frames:
|
||||
# resize image, keep aspect ratio
|
||||
h, w, _ = img.shape
|
||||
ratio = h * 1.0 / w
|
||||
if ratio > 1:
|
||||
h = resize
|
||||
w = int(h * 1.0 / ratio)
|
||||
else:
|
||||
w = resize
|
||||
h = int(w * ratio)
|
||||
input_image = cv2.resize(img, (w, h))
|
||||
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
|
||||
input_image = transforms.ToTensor()(img).unsqueeze(0)
|
||||
|
||||
# preprocess, (-1, 1)
|
||||
input_image = -1 + 2 * input_image
|
||||
|
|
@ -64,6 +93,6 @@ class StyleTransfer():
|
|||
output_image = np.rollaxis(output_image, 0, 3)
|
||||
|
||||
# append image to result images
|
||||
stylized_imgs.append(output_image)
|
||||
stylized_imgs.append(255 * output_image)
|
||||
|
||||
return stylized_imgs
|
||||
|
|
|
|||
Loading…
Reference in a new issue