forked from prehistoric-systems/comixify
Add comix gan (#16)
* Add ComixGAN #1 * Add minor fixes * Reduce batch size * Fix concatenate bug * Fix GPU memory not released problem #2 * Build client * Fix style_transfer_mode bug * Improve timings * Add minor fix with GPU name * Fix tf session problem * Compile frontend * Change parameters * Fix occasional yt Error
This commit is contained in:
parent
75810f7c84
commit
1e5252b8f0
17 changed files with 352 additions and 212 deletions
24
ComixGAN/model.py
Normal file
24
ComixGAN/model.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from django.conf import settings
|
||||||
|
from keras.models import load_model
|
||||||
|
from keras_contrib.layers import InstanceNormalization
|
||||||
|
|
||||||
|
|
||||||
|
class ComixGAN:
|
||||||
|
def __init__(self):
|
||||||
|
if not os.path.exists(settings.COMIX_GAN_MODEL_PATH):
|
||||||
|
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.COMIX_GAN_MODEL_PATH)
|
||||||
|
self.graph = tf.Graph()
|
||||||
|
config = tf.ConfigProto()
|
||||||
|
config.gpu_options.per_process_gpu_memory_fraction = 0.6
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
self.session = tf.Session(graph=self.graph, config=config)
|
||||||
|
with self.graph.as_default():
|
||||||
|
with self.session.as_default():
|
||||||
|
with tf.device('/device:GPU:0'):
|
||||||
|
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
|
||||||
|
custom_objects={'InstanceNormalization': InstanceNormalization})
|
||||||
|
|
||||||
BIN
ComixGAN/pretrained_models/generator_model.h5
Normal file
BIN
ComixGAN/pretrained_models/generator_model.h5
Normal file
Binary file not shown.
BIN
ComixGAN/pretrained_models/generator_model2.h5
Normal file
BIN
ComixGAN/pretrained_models/generator_model2.h5
Normal file
Binary file not shown.
|
|
@ -23,7 +23,7 @@ class Video(models.Model):
|
||||||
yt_pafy = pafy.new(yt_url)
|
yt_pafy = pafy.new(yt_url)
|
||||||
|
|
||||||
# Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px
|
# Use the biggest possible quality with file size < MAX_FILE_SIZE and resolution <= 480px
|
||||||
for stream in yt_pafy.videostreams:
|
for stream in reversed(yt_pafy.videostreams):
|
||||||
if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480:
|
if stream.get_filesize() < settings.MAX_FILE_SIZE and int(stream.quality.split("x")[1]) <= 480:
|
||||||
tmp_name = uuid.uuid4().hex + ".mp4"
|
tmp_name = uuid.uuid4().hex + ".mp4"
|
||||||
relative_path = jj('raw_videos', tmp_name)
|
relative_path = jj('raw_videos', tmp_name)
|
||||||
|
|
@ -34,22 +34,24 @@ class Video(models.Model):
|
||||||
else:
|
else:
|
||||||
raise TooLargeFile()
|
raise TooLargeFile()
|
||||||
|
|
||||||
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0):
|
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0, style_transfer_mode=0):
|
||||||
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
|
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
|
||||||
video=self,
|
video=self,
|
||||||
frames_mode=frames_mode,
|
frames_mode=frames_mode,
|
||||||
rl_mode=rl_mode,
|
rl_mode=rl_mode,
|
||||||
image_assessment_mode=image_assessment_mode
|
image_assessment_mode=image_assessment_mode
|
||||||
)
|
)
|
||||||
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes)
|
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes,
|
||||||
|
style_transfer_mode=style_transfer_mode)
|
||||||
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
|
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
|
||||||
|
|
||||||
timings = {
|
timings = {
|
||||||
'keyframes_extraction_time': keyframes_extraction_time,
|
'keyframes_extraction_time': keyframes_extraction_time,
|
||||||
'stylization_time': stylization_time,
|
'stylization_time': stylization_time,
|
||||||
'layout_generation_time': layout_generation_time,
|
'layout_generation_time': layout_generation_time,
|
||||||
**keyframes_timings
|
'keyframes_extraction_time_details': keyframes_timings
|
||||||
}
|
}
|
||||||
|
|
||||||
return comic_image, timings
|
return comic_image, timings
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,7 +63,7 @@ class Comic(models.Model):
|
||||||
@profile
|
@profile
|
||||||
def create_from_nparray(cls, nparray_file, video):
|
def create_from_nparray(cls, nparray_file, video):
|
||||||
if nparray_file.max() <= 1:
|
if nparray_file.max() <= 1:
|
||||||
nparray_file = (nparray_file * 255).astype(int)
|
nparray_file = (nparray_file).astype(int)
|
||||||
tmp_name = uuid.uuid4().hex + ".png"
|
tmp_name = uuid.uuid4().hex + ".png"
|
||||||
cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file)
|
cv2.imwrite(jj(settings.TMP_DIR, tmp_name), nparray_file)
|
||||||
with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file:
|
with open(jj(settings.TMP_DIR, tmp_name), mode="rb") as tmp_file:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,10 @@ class VideoSerializer(serializers.Serializer):
|
||||||
file = serializers.FileField()
|
file = serializers.FileField()
|
||||||
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
||||||
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
||||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
|
||||||
|
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||||
|
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
|
||||||
|
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
|
||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
file = attrs.get("file")
|
file = attrs.get("file")
|
||||||
|
|
@ -23,4 +26,7 @@ class YouTubeDownloadSerializer(serializers.Serializer):
|
||||||
url = serializers.URLField()
|
url = serializers.URLField()
|
||||||
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
|
||||||
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
|
||||||
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1,
|
||||||
|
default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
|
||||||
|
style_transfer_mode = serializers.IntegerField(min_value=0, max_value=2,
|
||||||
|
default=settings.DEFAULT_STYLE_TRANSFER_MODE)
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,8 @@ class Comixify(APIView):
|
||||||
comic_image, timings = video.create_comic(
|
comic_image, timings = video.create_comic(
|
||||||
frames_mode=serializer.validated_data["frames_mode"],
|
frames_mode=serializer.validated_data["frames_mode"],
|
||||||
rl_mode=serializer.validated_data["rl_mode"],
|
rl_mode=serializer.validated_data["rl_mode"],
|
||||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
|
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
|
||||||
|
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
|
||||||
)
|
)
|
||||||
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
||||||
timings['from_nparray_time'] = from_nparray_time
|
timings['from_nparray_time'] = from_nparray_time
|
||||||
|
|
@ -54,7 +55,8 @@ class ComixifyFromYoutube(APIView):
|
||||||
comic_image, timings = video.create_comic(
|
comic_image, timings = video.create_comic(
|
||||||
frames_mode=serializer.validated_data["frames_mode"],
|
frames_mode=serializer.validated_data["frames_mode"],
|
||||||
rl_mode=serializer.validated_data["rl_mode"],
|
rl_mode=serializer.validated_data["rl_mode"],
|
||||||
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
|
image_assessment_mode=serializer.validated_data["image_assessment_mode"],
|
||||||
|
style_transfer_mode=serializer.validated_data["style_transfer_mode"],
|
||||||
)
|
)
|
||||||
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
|
||||||
timings['from_nparray_time'] = from_nparray_time
|
timings['from_nparray_time'] = from_nparray_time
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,6 @@ class LayoutGenerator():
|
||||||
def _pad_images(frames):
|
def _pad_images(frames):
|
||||||
padded_result_imgs = []
|
padded_result_imgs = []
|
||||||
for img in frames:
|
for img in frames:
|
||||||
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(1, 1, 1))
|
padded_img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=(255, 255, 255))
|
||||||
padded_result_imgs.append(padded_img)
|
padded_result_imgs.append(padded_img)
|
||||||
return padded_result_imgs
|
return padded_result_imgs
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,7 @@ RUN apt-get update && apt-get install -y apt-utils software-properties-common &&
|
||||||
libsnappy-dev protobuf-compiler \
|
libsnappy-dev protobuf-compiler \
|
||||||
python-numpy python-setuptools python-scipy \
|
python-numpy python-setuptools python-scipy \
|
||||||
libavformat-dev libswscale-dev unzip && \
|
libavformat-dev libswscale-dev unzip && \
|
||||||
python3.6 -m pip install --upgrade pip && \
|
python3.6 -m pip install --upgrade pip
|
||||||
python3.6 -m pip install jupyter ipywidgets jupyterlab && \
|
|
||||||
python3.6 -m pip install h5py keras && \
|
|
||||||
python3.6 -m pip install scikit-image opencv-contrib-python pyyaml
|
|
||||||
|
|
||||||
RUN mkdir /comixify
|
RUN mkdir /comixify
|
||||||
COPY ./Makefile.config /comixify/Makefile.config
|
COPY ./Makefile.config /comixify/Makefile.config
|
||||||
|
|
@ -45,7 +42,8 @@ RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig &&
|
||||||
WORKDIR /comixify
|
WORKDIR /comixify
|
||||||
COPY . /comixify
|
COPY . /comixify
|
||||||
RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \
|
RUN unzip popularity/pretrained_model/svr_test_11.10.sk.zip -d popularity/pretrained_model/ && \
|
||||||
python3.6 -m pip install -r requirements.txt
|
python3.6 -m pip install -r requirements.txt && \
|
||||||
|
python3.6 -m pip install git+https://www.github.com/keras-team/keras-contrib.git
|
||||||
|
|
||||||
|
|
||||||
# Port to expose
|
# Port to expose
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,8 @@ class App extends React.Component {
|
||||||
result_comics: null,
|
result_comics: null,
|
||||||
framesMode: "0",
|
framesMode: "0",
|
||||||
rlMode: "0",
|
rlMode: "0",
|
||||||
imageAssessment: "0"
|
imageAssessment: "0",
|
||||||
|
styleTransferMode: "0",
|
||||||
};
|
};
|
||||||
this.onVideoDrop = this.onVideoDrop.bind(this);
|
this.onVideoDrop = this.onVideoDrop.bind(this);
|
||||||
this.onModelChange = this.onModelChange.bind(this);
|
this.onModelChange = this.onModelChange.bind(this);
|
||||||
|
|
@ -40,6 +41,7 @@ class App extends React.Component {
|
||||||
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
|
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
|
||||||
this.onSamplingChange = this.onSamplingChange.bind(this);
|
this.onSamplingChange = this.onSamplingChange.bind(this);
|
||||||
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
|
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
|
||||||
|
this.styleTransferChange = this.styleTransferChange.bind(this);
|
||||||
}
|
}
|
||||||
static onVideoUploadProgress(progressEvent) {
|
static onVideoUploadProgress(progressEvent) {
|
||||||
let percentCompleted = Math.round(
|
let percentCompleted = Math.round(
|
||||||
|
|
@ -52,6 +54,12 @@ class App extends React.Component {
|
||||||
this.setState({
|
this.setState({
|
||||||
rlMode: value
|
rlMode: value
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
styleTransferChange(e) {
|
||||||
|
let value = e.currentTarget.value;
|
||||||
|
this.setState({
|
||||||
|
styleTransferMode: value
|
||||||
|
})
|
||||||
}
|
}
|
||||||
onSamplingChange(e) {
|
onSamplingChange(e) {
|
||||||
let value = e.currentTarget.value;
|
let value = e.currentTarget.value;
|
||||||
|
|
@ -78,12 +86,13 @@ class App extends React.Component {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
processVideo(video) {
|
processVideo(video) {
|
||||||
let { framesMode, rlMode, imageAssessment } = this.state;
|
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
|
||||||
let data = new FormData();
|
let data = new FormData();
|
||||||
data.append("file", video);
|
data.append("file", video);
|
||||||
data.set('frames_mode', parseInt(framesMode));
|
data.set('frames_mode', parseInt(framesMode));
|
||||||
data.set('rl_mode', parseInt(rlMode));
|
data.set('rl_mode', parseInt(rlMode));
|
||||||
data.set("image_assessment_mode", parseInt(imageAssessment));
|
data.set("image_assessment_mode", parseInt(imageAssessment));
|
||||||
|
data.set('style_transfer_mode', parseInt(styleTransferMode));
|
||||||
post(COMIXIFY_API, data, {
|
post(COMIXIFY_API, data, {
|
||||||
headers: { "content-type": "multipart/form-data" },
|
headers: { "content-type": "multipart/form-data" },
|
||||||
onUploadProgress: App.onVideoUploadProgress
|
onUploadProgress: App.onVideoUploadProgress
|
||||||
|
|
@ -111,12 +120,13 @@ class App extends React.Component {
|
||||||
this.processVideo(files[0]);
|
this.processVideo(files[0]);
|
||||||
}
|
}
|
||||||
submitYouTube(link) {
|
submitYouTube(link) {
|
||||||
let { framesMode, rlMode, imageAssessment } = this.state;
|
let { framesMode, rlMode, imageAssessment, styleTransferMode } = this.state;
|
||||||
post(FROM_YOUTUBE_API, {
|
post(FROM_YOUTUBE_API, {
|
||||||
url: link,
|
url: link,
|
||||||
frames_mode: parseInt(framesMode),
|
frames_mode: parseInt(framesMode),
|
||||||
rl_mode: parseInt(rlMode),
|
rl_mode: parseInt(rlMode),
|
||||||
image_assessment_mode: parseInt(imageAssessment)
|
image_assessment_mode: parseInt(imageAssessment),
|
||||||
|
style_transfer_mode: parseInt(styleTransferMode)
|
||||||
})
|
})
|
||||||
.then(this.handleResponse)
|
.then(this.handleResponse)
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
|
|
@ -143,7 +153,7 @@ class App extends React.Component {
|
||||||
}
|
}
|
||||||
render() {
|
render() {
|
||||||
let {
|
let {
|
||||||
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment
|
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment, styleTransferMode
|
||||||
} = this.state;
|
} = this.state;
|
||||||
let showUsage = [
|
let showUsage = [
|
||||||
App.appStates.INITIAL,
|
App.appStates.INITIAL,
|
||||||
|
|
@ -231,6 +241,36 @@ class App extends React.Component {
|
||||||
onChange={this.onImageAssessmentChange}
|
onChange={this.onImageAssessmentChange}
|
||||||
/>
|
/>
|
||||||
<label htmlFor="image-assessment-1">Popularity</label>
|
<label htmlFor="image-assessment-1">Popularity</label>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span>Style Transfer model:</span>
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="style-model"
|
||||||
|
id="style-model-0"
|
||||||
|
value="0"
|
||||||
|
checked={styleTransferMode === "0"}
|
||||||
|
onChange={this.styleTransferChange}
|
||||||
|
/>
|
||||||
|
<label htmlFor="style-model-0">ComixGAN</label>
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="style-model"
|
||||||
|
id="style-model-1"
|
||||||
|
value="1"
|
||||||
|
checked={styleTransferMode === "1"}
|
||||||
|
onChange={this.styleTransferChange}
|
||||||
|
/>
|
||||||
|
<label htmlFor="style-model-1">CartoonGAN-Hayao</label>
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="style-model"
|
||||||
|
id="style-model-2"
|
||||||
|
value="2"
|
||||||
|
checked={styleTransferMode === "2"}
|
||||||
|
onChange={this.styleTransferChange}
|
||||||
|
/>
|
||||||
|
<label htmlFor="style-model-2">CartoonGAN-Hosoda</label>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -60,7 +60,7 @@
|
||||||
|
|
||||||
<i>* Max size of video is 50 MB</i>
|
<i>* Max size of video is 50 MB</i>
|
||||||
|
|
||||||
<h4>Respose:</h4>
|
<h4>Response:</h4>
|
||||||
<code>
|
<code>
|
||||||
{
|
{
|
||||||
<br>
|
<br>
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
from utils import jj, profile
|
from utils import jj, profile
|
||||||
from keyframes_rl.models import DSN
|
from keyframes_rl.models import DSN
|
||||||
from popularity.models import PopularityPredictor
|
from popularity.models import PopularityPredictor
|
||||||
from neural_image_assessment.models import NeuralImageAssessment
|
from neural_image_assessment.model import NeuralImageAssessment
|
||||||
from keyframes.kts import cpd_auto
|
from keyframes.kts import cpd_auto
|
||||||
from keyframes.utils import batch
|
from keyframes.utils import batch
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,27 @@
|
||||||
import os
|
|
||||||
import errno
|
import errno
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from keras.models import load_model
|
|
||||||
from keras.preprocessing.image import img_to_array
|
|
||||||
from keras.applications.nasnet import preprocess_input
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from keras.applications.nasnet import preprocess_input
|
||||||
|
from keras.models import load_model
|
||||||
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
|
from keras.preprocessing.image import img_to_array
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
|
||||||
class NeuralImageAssessment:
|
class NeuralImageAssessment:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not os.path.exists(MODEL_PATH):
|
if not os.path.exists(settings.NIMA_MODEL_PATH):
|
||||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)
|
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
|
||||||
self.graph = tf.Graph()
|
self.graph = tf.Graph()
|
||||||
config = tf.ConfigProto()
|
config = tf.ConfigProto()
|
||||||
config.gpu_options.per_process_gpu_memory_fraction = 0.1
|
config.gpu_options.per_process_gpu_memory_fraction = 0.6
|
||||||
config.gpu_options.allow_growth = True
|
config.gpu_options.allow_growth = True
|
||||||
self.session = tf.Session(graph=self.graph, config=config)
|
self.session = tf.Session(graph=self.graph, config=config)
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.model = load_model(MODEL_PATH)
|
with self.session.as_default():
|
||||||
|
self.model = load_model(settings.NIMA_MODEL_PATH)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
|
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
|
||||||
|
|
@ -46,13 +46,14 @@ class NeuralImageAssessment:
|
||||||
|
|
||||||
def get_assessment_score(self, img_array):
|
def get_assessment_score(self, img_array):
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
target_size = (224, 224)
|
with self.session.as_default():
|
||||||
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
target_size = (224, 224)
|
||||||
x = img_to_array(img)
|
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
||||||
x = np.expand_dims(x, axis=0)
|
x = img_to_array(img)
|
||||||
x = preprocess_input(x)
|
x = np.expand_dims(x, axis=0)
|
||||||
|
x = preprocess_input(x)
|
||||||
|
|
||||||
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
||||||
mean = NeuralImageAssessment.mean_score(scores)
|
mean = NeuralImageAssessment.mean_score(scores)
|
||||||
|
|
||||||
return mean
|
return mean
|
||||||
|
|
@ -72,6 +72,10 @@ echo 'export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
|
||||||
source ~/.bashrc
|
source ~/.bashrc
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK CUDNN VERSION
|
||||||
|
cat /usr/include/x86_64-linux-gnu/cudnn_v*.h | grep CUDNN_MAJOR -A 2
|
||||||
|
|
||||||
|
|
||||||
# INSTALL PACKAGES
|
# INSTALL PACKAGES
|
||||||
export LC_ALL="en_US.UTF-8"
|
export LC_ALL="en_US.UTF-8"
|
||||||
export LC_CTYPE="en_US.UTF-8"
|
export LC_CTYPE="en_US.UTF-8"
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,45 @@
|
||||||
|
astor==0.7.1
|
||||||
|
cloudpickle==0.6.1
|
||||||
|
cycler==0.10.0
|
||||||
|
dask==0.20.0
|
||||||
|
decorator==4.3.0
|
||||||
Django==2.0.7
|
Django==2.0.7
|
||||||
django-rest-framework==0.1.0
|
django-rest-framework==0.1.0
|
||||||
djangorestframework==3.8.2
|
djangorestframework==3.8.2
|
||||||
|
gast==0.2.0
|
||||||
|
grpcio==1.16.0
|
||||||
gunicorn==19.9.0
|
gunicorn==19.9.0
|
||||||
|
h5py==2.8.0
|
||||||
|
Keras==2.2.4
|
||||||
|
Keras-Applications==1.0.6
|
||||||
|
Keras-Preprocessing==1.0.5
|
||||||
|
kiwisolver==1.0.1
|
||||||
|
Markdown==3.0.1
|
||||||
|
matplotlib==3.0.1
|
||||||
|
networkx==2.2
|
||||||
numpy==1.14.5
|
numpy==1.14.5
|
||||||
|
opencv-contrib-python==3.4.3.18
|
||||||
opencv-python==3.4.2.17
|
opencv-python==3.4.2.17
|
||||||
pafy==0.5.4
|
pafy==0.5.4
|
||||||
Pillow==5.2.0
|
Pillow==5.2.0
|
||||||
|
protobuf==3.6.1
|
||||||
psycopg2==2.7.5
|
psycopg2==2.7.5
|
||||||
|
pyparsing==2.3.0
|
||||||
|
python-dateutil==2.7.5
|
||||||
pytz==2018.5
|
pytz==2018.5
|
||||||
|
PyWavelets==1.0.1
|
||||||
|
PyYAML==3.13
|
||||||
|
scikit-image==0.14.1
|
||||||
|
scikit-learn==0.20.0
|
||||||
|
scipy==1.1.0
|
||||||
six==1.11.0
|
six==1.11.0
|
||||||
|
sklearn==0.0
|
||||||
|
tensorboard==1.10.0
|
||||||
|
tensorflow-gpu==1.10.1
|
||||||
|
termcolor==1.1.0
|
||||||
|
toolz==0.9.0
|
||||||
torch==0.4.1
|
torch==0.4.1
|
||||||
torchvision==0.2.1
|
torchvision==0.2.1
|
||||||
scikit-learn==0.19.2
|
Werkzeug==0.14.1
|
||||||
youtube-dl==2018.9.18
|
youtube-dl==2018.11.7
|
||||||
tensorflow-gpu==1.10.1
|
tensorflow-gpu==1.10.1
|
||||||
|
|
@ -1,148 +1,153 @@
|
||||||
"""
|
"""
|
||||||
Django settings for comixify project.
|
Django settings for comixify project.
|
||||||
|
|
||||||
Generated by 'django-admin startproject' using Django 2.0.7.
|
Generated by 'django-admin startproject' using Django 2.0.7.
|
||||||
|
|
||||||
For more information on this file, see
|
For more information on this file, see
|
||||||
https://docs.djangoproject.com/en/2.0/topics/settings/
|
https://docs.djangoproject.com/en/2.0/topics/settings/
|
||||||
|
|
||||||
For the full list of settings and their values, see
|
For the full list of settings and their values, see
|
||||||
https://docs.djangoproject.com/en/2.0/ref/settings/
|
https://docs.djangoproject.com/en/2.0/ref/settings/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
# Quick-start development settings - unsuitable for production
|
# Quick-start development settings - unsuitable for production
|
||||||
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
|
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
|
||||||
|
|
||||||
# SECURITY WARNING: keep the secret key used in production secret!
|
# SECURITY WARNING: keep the secret key used in production secret!
|
||||||
# Use a separate file for the secret key
|
# Use a separate file for the secret key
|
||||||
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
|
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
|
||||||
SECRET_KEY = f.read().strip()
|
SECRET_KEY = f.read().strip()
|
||||||
|
|
||||||
# SECURITY WARNING: don't run with debug turned on in production!
|
# SECURITY WARNING: don't run with debug turned on in production!
|
||||||
DEBUG = os.environ.get('DEBUG') == 'true'
|
DEBUG = os.environ.get('DEBUG') == 'true'
|
||||||
|
|
||||||
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
|
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
|
||||||
|
|
||||||
# Application definition
|
# Application definition
|
||||||
|
|
||||||
INSTALLED_APPS = [
|
INSTALLED_APPS = [
|
||||||
'django.contrib.admin',
|
'django.contrib.admin',
|
||||||
'django.contrib.auth',
|
'django.contrib.auth',
|
||||||
'django.contrib.contenttypes',
|
'django.contrib.contenttypes',
|
||||||
'django.contrib.sessions',
|
'django.contrib.sessions',
|
||||||
'django.contrib.messages',
|
'django.contrib.messages',
|
||||||
'django.contrib.staticfiles',
|
'django.contrib.staticfiles',
|
||||||
'rest_framework',
|
'rest_framework',
|
||||||
'api',
|
'api',
|
||||||
'style_transfer',
|
'style_transfer',
|
||||||
'comic_layout',
|
'comic_layout',
|
||||||
'frontend',
|
'frontend',
|
||||||
]
|
]
|
||||||
|
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
'django.middleware.security.SecurityMiddleware',
|
'django.middleware.security.SecurityMiddleware',
|
||||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||||
'django.middleware.common.CommonMiddleware',
|
'django.middleware.common.CommonMiddleware',
|
||||||
'django.middleware.csrf.CsrfViewMiddleware',
|
'django.middleware.csrf.CsrfViewMiddleware',
|
||||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||||
'django.contrib.messages.middleware.MessageMiddleware',
|
'django.contrib.messages.middleware.MessageMiddleware',
|
||||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||||
]
|
]
|
||||||
|
|
||||||
ROOT_URLCONF = 'settings.urls'
|
ROOT_URLCONF = 'settings.urls'
|
||||||
|
|
||||||
TEMPLATES = [
|
TEMPLATES = [
|
||||||
{
|
{
|
||||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||||
'DIRS': [],
|
'DIRS': [],
|
||||||
'APP_DIRS': True,
|
'APP_DIRS': True,
|
||||||
'OPTIONS': {
|
'OPTIONS': {
|
||||||
'context_processors': [
|
'context_processors': [
|
||||||
'django.template.context_processors.debug',
|
'django.template.context_processors.debug',
|
||||||
'django.template.context_processors.request',
|
'django.template.context_processors.request',
|
||||||
'django.contrib.auth.context_processors.auth',
|
'django.contrib.auth.context_processors.auth',
|
||||||
'django.contrib.messages.context_processors.messages',
|
'django.contrib.messages.context_processors.messages',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
WSGI_APPLICATION = 'settings.wsgi.application'
|
WSGI_APPLICATION = 'settings.wsgi.application'
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
|
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
|
||||||
|
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
'default': {
|
'default': {
|
||||||
'ENGINE': 'django.db.backends.postgresql_psycopg2',
|
'ENGINE': 'django.db.backends.postgresql_psycopg2',
|
||||||
'NAME': 'postgres',
|
'NAME': 'postgres',
|
||||||
'USER': 'postgres',
|
'USER': 'postgres',
|
||||||
'HOST': 'db',
|
'HOST': 'db',
|
||||||
'PORT': '5432',
|
'PORT': '5432',
|
||||||
'PASSWORD': 'postgres',
|
'PASSWORD': 'postgres',
|
||||||
'CONN_MAX_AGE': 60,
|
'CONN_MAX_AGE': 60,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Password validation
|
# Password validation
|
||||||
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
|
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
|
||||||
|
|
||||||
AUTH_PASSWORD_VALIDATORS = [
|
AUTH_PASSWORD_VALIDATORS = [
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
CACHES = {
|
CACHES = {
|
||||||
'default': {
|
'default': {
|
||||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||||
'LOCATION': 'unique-snowflake',
|
'LOCATION': 'unique-snowflake',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Internationalization
|
# Internationalization
|
||||||
# https://docs.djangoproject.com/en/2.0/topics/i18n/
|
# https://docs.djangoproject.com/en/2.0/topics/i18n/
|
||||||
|
|
||||||
LANGUAGE_CODE = 'en-us'
|
LANGUAGE_CODE = 'en-us'
|
||||||
|
|
||||||
TIME_ZONE = 'UTC'
|
TIME_ZONE = 'UTC'
|
||||||
|
|
||||||
USE_I18N = True
|
USE_I18N = True
|
||||||
|
|
||||||
USE_L10N = True
|
USE_L10N = True
|
||||||
|
|
||||||
USE_TZ = True
|
USE_TZ = True
|
||||||
|
|
||||||
# Static files (CSS, JavaScript, Images)
|
# Static files (CSS, JavaScript, Images)
|
||||||
# https://docs.djangoproject.com/en/2.0/howto/static-files/
|
# https://docs.djangoproject.com/en/2.0/howto/static-files/
|
||||||
|
|
||||||
STATIC_URL = '/static/'
|
STATIC_URL = '/static/'
|
||||||
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
||||||
|
|
||||||
MEDIA_URL = '/media/'
|
MEDIA_URL = '/media/'
|
||||||
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
||||||
|
|
||||||
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
|
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
|
||||||
MAX_FILE_SIZE = 50000000
|
MAX_FILE_SIZE = 50000000
|
||||||
NUMBERS_OF_FRAMES_TO_SHOW = 10
|
NUMBERS_OF_FRAMES_TO_SHOW = 10
|
||||||
TMP_DIR = 'tmp/'
|
TMP_DIR = 'tmp/'
|
||||||
GPU = True
|
GPU = True
|
||||||
|
|
||||||
FEATURE_BATCH_SIZE = 32
|
FEATURE_BATCH_SIZE = 32
|
||||||
DEFAULT_FRAMES_SAMPLING_MODE = 0
|
DEFAULT_FRAMES_SAMPLING_MODE = 0
|
||||||
DEFAULT_RL_MODE = 0
|
DEFAULT_RL_MODE = 0
|
||||||
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
|
DEFAULT_IMAGE_ASSESSMENT_MODE = 0
|
||||||
DEBUG = True
|
|
||||||
|
DEFAULT_STYLE_TRANSFER_MODE = 0
|
||||||
|
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
|
||||||
|
MAX_FRAME_SIZE_FOR_STYLE_TRANSFER = 600
|
||||||
|
|
||||||
|
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')
|
||||||
|
|
@ -9,21 +9,59 @@ from django.core.cache import cache
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
|
||||||
from CartoonGAN.network.Transformer import Transformer
|
from CartoonGAN.network.Transformer import Transformer
|
||||||
|
from ComixGAN.model import ComixGAN
|
||||||
from utils import profile
|
from utils import profile
|
||||||
|
|
||||||
|
# load pretrained model
|
||||||
|
comixGAN = ComixGAN()
|
||||||
|
|
||||||
|
|
||||||
class StyleTransfer():
|
class StyleTransfer():
|
||||||
@classmethod
|
@classmethod
|
||||||
@profile
|
@profile
|
||||||
def get_stylized_frames(cls, frames, method="cartoon_gan", gpu=settings.GPU, **kwargs):
|
def get_stylized_frames(cls, frames, style_transfer_mode=0, gpu=settings.GPU):
|
||||||
if method == "cartoon_gan":
|
if style_transfer_mode == 0:
|
||||||
return cls._cartoon_gan_stylize(frames, gpu=gpu, **kwargs)
|
return cls._comix_gan_stylize(frames=frames)
|
||||||
|
elif style_transfer_mode == 1:
|
||||||
|
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hayao')
|
||||||
|
elif style_transfer_mode == 2:
|
||||||
|
return cls._cartoon_gan_stylize(frames, gpu=gpu, style='Hosoda')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _cartoon_gan_stylize(frames, gpu=True, **kwargs):
|
def _resize_images(frames, size=384):
|
||||||
style = kwargs.get("style", "Hayao")
|
resized_images = []
|
||||||
resize = kwargs.get("resize", 450)
|
for img in frames:
|
||||||
|
# resize image, keep aspect ratio
|
||||||
|
h, w, _ = img.shape
|
||||||
|
ratio = h / w
|
||||||
|
if ratio > 1:
|
||||||
|
h = size
|
||||||
|
w = int(h * 1.0 / ratio)
|
||||||
|
else:
|
||||||
|
w = size
|
||||||
|
h = int(w * ratio)
|
||||||
|
resized_img = cv2.resize(img, (w, h))
|
||||||
|
resized_images.append(resized_img)
|
||||||
|
return resized_images
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _comix_gan_stylize(cls, frames):
|
||||||
|
if max(frames[0].shape) > settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER:
|
||||||
|
frames = cls._resize_images(frames, size=settings.MAX_FRAME_SIZE_FOR_STYLE_TRANSFER)
|
||||||
|
|
||||||
|
with comixGAN.graph.as_default():
|
||||||
|
with comixGAN.session.as_default():
|
||||||
|
batch_size = 4
|
||||||
|
stylized_imgs = []
|
||||||
|
for i in range(0, len(frames), batch_size):
|
||||||
|
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
|
||||||
|
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
|
||||||
|
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 1.25))
|
||||||
|
|
||||||
|
return list(np.concatenate(stylized_imgs, axis=0))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _cartoon_gan_stylize(cls, frames, gpu=True, style='Hayao'):
|
||||||
model_cache_key = 'model_cache'
|
model_cache_key = 'model_cache'
|
||||||
model = cache.get(model_cache_key) # get model from cache
|
model = cache.get(model_cache_key) # get model from cache
|
||||||
|
|
||||||
|
|
@ -35,19 +73,10 @@ class StyleTransfer():
|
||||||
model.cuda() if gpu else model.float()
|
model.cuda() if gpu else model.float()
|
||||||
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
|
cache.set(model_cache_key, model, None) # None is the timeout parameter. It means cache forever
|
||||||
|
|
||||||
|
frames = cls._resize_images(frames, size=450)
|
||||||
stylized_imgs = []
|
stylized_imgs = []
|
||||||
for img in frames:
|
for img in frames:
|
||||||
# resize image, keep aspect ratio
|
input_image = transforms.ToTensor()(img).unsqueeze(0)
|
||||||
h, w, _ = img.shape
|
|
||||||
ratio = h * 1.0 / w
|
|
||||||
if ratio > 1:
|
|
||||||
h = resize
|
|
||||||
w = int(h * 1.0 / ratio)
|
|
||||||
else:
|
|
||||||
w = resize
|
|
||||||
h = int(w * ratio)
|
|
||||||
input_image = cv2.resize(img, (w, h))
|
|
||||||
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
|
|
||||||
|
|
||||||
# preprocess, (-1, 1)
|
# preprocess, (-1, 1)
|
||||||
input_image = -1 + 2 * input_image
|
input_image = -1 + 2 * input_image
|
||||||
|
|
@ -64,6 +93,6 @@ class StyleTransfer():
|
||||||
output_image = np.rollaxis(output_image, 0, 3)
|
output_image = np.rollaxis(output_image, 0, 3)
|
||||||
|
|
||||||
# append image to result images
|
# append image to result images
|
||||||
stylized_imgs.append(output_image)
|
stylized_imgs.append(255 * output_image)
|
||||||
|
|
||||||
return stylized_imgs
|
return stylized_imgs
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue