Add NIMA (#13)

* Add frontend

* Neural Image Assessment (NIMA) model

* Add GPU mem fraction to NIMA model

* Add backend logic

* Remove whitespace

* NIMA score from image array

* Remove NIMA caching

* Add nima model as global

* NIMA model - add graph to class

* Set DEBUG=True

* Raise exception if NIMA model file is not found
This commit is contained in:
Adam Svystun 2018-11-09 00:44:19 +01:00 committed by GitHub
parent 8c627690f1
commit d453748661
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 291 additions and 175 deletions

View file

@ -34,11 +34,12 @@ class Video(models.Model):
else:
raise TooLargeFile()
def create_comic(self, frames_mode=0, rl_mode=0):
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0):
keyframes, keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
video=self,
frames_mode=frames_mode,
rl_mode=rl_mode
rl_mode=rl_mode,
image_assessment_mode=image_assessment_mode
)
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes)
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)

View file

@ -8,6 +8,7 @@ class VideoSerializer(serializers.Serializer):
file = serializers.FileField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)
def validate(self, attrs):
file = attrs.get("file")
@ -22,3 +23,4 @@ class YouTubeDownloadSerializer(serializers.Serializer):
url = serializers.URLField()
frames_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_FRAMES_SAMPLING_MODE)
rl_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_RL_MODE)
image_assessment_mode = serializers.IntegerField(min_value=0, max_value=1, default=settings.DEFAULT_IMAGE_ASSESSMENT_MODE)

View file

@ -21,7 +21,8 @@ class Comixify(APIView):
video = Video.objects.create(file=video_file)
comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"]
rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
)
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time
@ -52,7 +53,8 @@ class ComixifyFromYoutube(APIView):
video.save()
comic_image, timings = video.create_comic(
frames_mode=serializer.validated_data["frames_mode"],
rl_mode=serializer.validated_data["rl_mode"]
rl_mode=serializer.validated_data["rl_mode"],
image_assessment_mode=serializer.validated_data["image_assessment_mode"]
)
comic, from_nparray_time = Comic.create_from_nparray(comic_image, video)
timings['from_nparray_time'] = from_nparray_time

View file

@ -31,13 +31,15 @@ class App extends React.Component {
drop_errors: [],
result_comics: null,
framesMode: "0",
rlMode: "0"
rlMode: "0",
imageAssessment: "0"
};
this.onVideoDrop = this.onVideoDrop.bind(this);
this.onModelChange = this.onModelChange.bind(this);
this.handleResponse = this.handleResponse.bind(this);
this.onYouTubeSubmit = this.onYouTubeSubmit.bind(this);
this.onSamplingChange = this.onSamplingChange.bind(this);
this.onImageAssessmentChange = this.onImageAssessmentChange.bind(this);
}
static onVideoUploadProgress(progressEvent) {
let percentCompleted = Math.round(
@ -56,6 +58,12 @@ class App extends React.Component {
this.setState({
framesMode: value
})
}
onImageAssessmentChange(e) {
let value = e.currentTarget.value;
this.setState({
imageAssessment: value
})
}
handleResponse(res) {
if (res.data["status_message"] === "ok") {
@ -70,11 +78,12 @@ class App extends React.Component {
}
}
processVideo(video) {
let { framesMode, rlMode } = this.state
let { framesMode, rlMode, imageAssessment } = this.state;
let data = new FormData();
data.append("file", video);
data.set('frames_mode', parseInt(framesMode));
data.set('rl_mode', parseInt(rlMode));
data.set("image_assessment_mode", parseInt(imageAssessment));
post(COMIXIFY_API, data, {
headers: { "content-type": "multipart/form-data" },
onUploadProgress: App.onVideoUploadProgress
@ -102,11 +111,12 @@ class App extends React.Component {
this.processVideo(files[0]);
}
submitYouTube(link) {
let { framesMode, rlMode } = this.state;
let { framesMode, rlMode, imageAssessment } = this.state;
post(FROM_YOUTUBE_API, {
url: link,
frames_mode: parseInt(framesMode),
rl_mode: parseInt(rlMode)
rl_mode: parseInt(rlMode),
image_assessment_mode: parseInt(imageAssessment)
})
.then(this.handleResponse)
.catch(err => {
@ -132,7 +142,9 @@ class App extends React.Component {
});
}
render() {
let { state, drop_errors, result_comics, framesMode, rlMode, videoId } = this.state;
let {
state, drop_errors, result_comics, framesMode, rlMode, videoId, imageAssessment
} = this.state;
let showUsage = [
App.appStates.INITIAL,
App.appStates.UPLOAD_ERROR,
@ -198,9 +210,30 @@ class App extends React.Component {
onChange={this.onModelChange}
/>
<label htmlFor="model-1">+VTW model</label>
</div>
<div>
<span>Image assessment:</span>
<input
type="radio"
name="image-assessment"
id="image-assessment-0"
value="0"
checked={imageAssessment === "0"}
onChange={this.onImageAssessmentChange}
/>
<label htmlFor="image-assessment-0">NIMA</label>
<input
type="radio"
name="image-assessment"
id="image-assessment-1"
value="1"
checked={imageAssessment === "1"}
onChange={this.onImageAssessmentChange}
/>
<label htmlFor="image-assessment-1">Popularity</label>
</div>
</div>
)}
)}
{showUsage && (
<Dropzone
onDrop={this.onVideoDrop}

File diff suppressed because one or more lines are too long

View file

@ -18,17 +18,20 @@ import logging
from utils import jj, profile
from keyframes_rl.models import DSN
from popularity.models import PopularityPredictor
from neural_image_assessment.models import NeuralImageAssessment
from keyframes.kts import cpd_auto
from keyframes.utils import batch
logger = logging.getLogger(__name__)
nima_model = NeuralImageAssessment()
class KeyFramesExtractor:
@classmethod
@profile
def get_keyframes(cls, video, gpu=settings.GPU, features_batch_size=settings.FEATURE_BATCH_SIZE,
frames_mode=0, rl_mode=0):
frames_mode=0, rl_mode=0, image_assessment_mode=0):
frames_paths, all_frames_tmp_dir = cls._get_all_frames(video, mode=frames_mode)
frames = cls._get_frames(frames_paths)
features = cls._get_features(frames, gpu, features_batch_size)
@ -36,7 +39,7 @@ class KeyFramesExtractor:
change_points, frames_per_segment = cls._get_segments(norm_features)
probs = cls._get_probs(norm_features, gpu, mode=rl_mode)
keyframes = cls._get_keyframes(frames, probs, change_points, frames_per_segment)
chosen_frames = cls._get_popularity_chosen_frames(keyframes, features)
chosen_frames = cls._get_popularity_chosen_frames(keyframes, features, image_assessment_mode)
shutil.rmtree(jj(f"{settings.TMP_DIR}", f"{all_frames_tmp_dir}"))
return chosen_frames
@ -156,18 +159,20 @@ class KeyFramesExtractor:
return chosen_frames
@staticmethod
def _get_popularity_chosen_frames(frames, features, n_frames=10):
model_cache_key = "popularity_model_cache"
model = cache.get(model_cache_key) # get model from cache
if model is None:
model = PopularityPredictor()
cache.set(model_cache_key, model, None)
for frame in frames:
x = features[frame["index"]]
frame["popularity"] = model.get_popularity_score(x).squeeze()
def _get_popularity_chosen_frames(frames, features, image_assessment_mode=0, n_frames=10):
if image_assessment_mode == 1:
model_cache_key = "popularity_model_cache"
model = cache.get(model_cache_key) # get model from cache
if model is None:
model = PopularityPredictor()
cache.set(model_cache_key, model, None)
for frame in frames:
x = features[frame["index"]]
frame["popularity"] = model.get_popularity_score(x).squeeze()
else:
for frame in frames:
x = frame["frame"]
frame["popularity"] = nima_model.get_assessment_score(x)
chosen_frames = sorted(frames, key=lambda k: k['popularity'], reverse=True)
chosen_frames = chosen_frames[0:n_frames]
chosen_frames.sort(key=lambda k: k['index'])

View file

@ -0,0 +1,71 @@
import os
import errno
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import load_img, img_to_array
from keras.applications.nasnet import preprocess_input
import tensorflow as tf
from PIL import Image
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
class NeuralImageAssessment:
    """Score images with a pretrained NIMA (Neural Image Assessment) model.

    The Keras model is loaded once, into a TensorFlow graph and session that
    belong to this instance. Both are made current for every Keras call so
    that the model never depends on (or leaks into) Keras' implicit global
    session.
    """

    def __init__(self):
        """Load the pretrained model found at ``MODEL_PATH``.

        Raises:
            FileNotFoundError: if ``MODEL_PATH`` does not exist.
        """
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)

        self.graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.001
        config.gpu_options.allow_growth = True
        self.session = tf.Session(graph=self.graph, config=config)
        # Enter the session as well as the graph: otherwise Keras creates and
        # uses its own global session and the one configured above is unused.
        with self.graph.as_default(), self.session.as_default():
            with tf.device('/CPU:0'):
                self.model = load_model(MODEL_PATH)

    @staticmethod
    def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
        """Convert a BGR image array to a PIL image of ``target_size``.

        Args:
            bgr_img_array: image array whose last axis is reversed before the
                PIL image is built (BGR -> RGB, going by the parameter name).
            target_size: ``(height, width)`` of the returned image.
            interpolation: one of ``'nearest'``, ``'bilinear'``, ``'bicubic'``.

        Returns:
            A ``PIL.Image.Image``; resized only if its size differs from
            ``target_size``.

        Raises:
            ValueError: if a resize is needed and ``interpolation`` is not a
                supported method name. The name is not checked when the image
                already has the target size.
        """
        _PIL_INTERPOLATION_METHODS = {
            'nearest': Image.NEAREST,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
        }
        img = Image.fromarray(np.uint8(bgr_img_array[..., ::-1]))
        # PIL sizes are (width, height) while target_size is (height, width).
        width_height_tuple = (target_size[1], target_size[0])
        if img.size != width_height_tuple:
            if interpolation not in _PIL_INTERPOLATION_METHODS:
                raise ValueError(
                    'Invalid interpolation method {} specified. Supported '
                    'methods are {}'.format(
                        interpolation,
                        ", ".join(_PIL_INTERPOLATION_METHODS.keys())))
            resample = _PIL_INTERPOLATION_METHODS[interpolation]
            img = img.resize(width_height_tuple, resample)
        return img

    def get_assessment_score(self, img_array):
        """Return the mean NIMA score predicted for a single image array.

        The image is resized to 224x224, run through the NASNet input
        preprocessing and fed to the model as a batch of one.
        """
        target_size = (224, 224)
        img = NeuralImageAssessment.resize_image(img_array, target_size)
        x = img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        # Predict in the same graph/session pair the model was loaded into.
        with self.graph.as_default(), self.session.as_default():
            scores = self.model.predict(x, batch_size=1, verbose=0)[0]
        return NeuralImageAssessment.mean_score(scores)

    @staticmethod
    def mean_score(scores):
        """Return the sum of ``scores`` weighted by the ratings 1..10.

        ``scores`` must hold 10 values (presumably the model's probability
        for each rating -- confirm against the model's output layer).
        """
        si = np.arange(1, 11, 1)
        mean = np.sum(scores * si)
        return mean

    @staticmethod
    def std_score(scores):
        """Return the spread of the ratings 1..10 around ``mean_score(scores)``.

        Computed as ``sqrt(sum((rating - mean) ** 2 * scores))``.
        """
        si = np.arange(1, 11, 1)
        mean = NeuralImageAssessment.mean_score(scores)
        std = np.sqrt(np.sum(((si - mean) ** 2) * scores))
        return std

Binary file not shown.

View file

@ -1,146 +1,148 @@
"""
Django settings for comixify project.
Generated by 'django-admin startproject' using Django 2.0.7.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
# Use a separate file for the secret key
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
SECRET_KEY = f.read().strip()
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG') == 'true'
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'api',
'style_transfer',
'comic_layout',
'frontend',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'settings.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'settings.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'postgres',
'USER': 'postgres',
'HOST': 'db',
'PORT': '5432',
'PASSWORD': 'postgres',
'CONN_MAX_AGE': 60,
}
}
# Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
CACHES = {
'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake',
}
}
# Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/
STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
MAX_FILE_SIZE = 50000000
NUMBERS_OF_FRAMES_TO_SHOW = 10
TMP_DIR = 'tmp/'
GPU = True
FEATURE_BATCH_SIZE = 32
DEFAULT_FRAMES_SAMPLING_MODE = 0
DEFAULT_RL_MODE = 0
"""
Django settings for comixify project.
Generated by 'django-admin startproject' using Django 2.0.7.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
# Use a separate file for the secret key
with open(os.path.join(BASE_DIR, 'secretkey.txt')) as f:
SECRET_KEY = f.read().strip()
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG') == 'true'
ALLOWED_HOSTS = ['35.241.250.34', 'comixify.ii.pw.edu.pl', 'localhost', '127.0.0.1']
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'api',
'style_transfer',
'comic_layout',
'frontend',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'settings.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'settings.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'postgres',
'USER': 'postgres',
'HOST': 'db',
'PORT': '5432',
'PASSWORD': 'postgres',
'CONN_MAX_AGE': 60,
}
}
# Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
CACHES = {
'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake',
}
}
# Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/
STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
# --- Project-specific settings -------------------------------------------

# Upload restrictions for video files.
PERMITTED_VIDEO_EXTENSIONS = ['mp4', 'avi']
MAX_FILE_SIZE = 50000000  # presumably bytes -- confirm against the upload check
NUMBERS_OF_FRAMES_TO_SHOW = 10
TMP_DIR = 'tmp/'

# Key-frame extraction defaults.
GPU = True
FEATURE_BATCH_SIZE = 32

# Defaults for the optional integer request fields of the API serializers.
DEFAULT_FRAMES_SAMPLING_MODE = 0
DEFAULT_RL_MODE = 0
# 0 ranks frames with the NIMA model, 1 with the popularity predictor.
DEFAULT_IMAGE_ASSESSMENT_MODE = 0

# Debug stays on by default, but is no longer forced: setting the DEBUG
# environment variable to anything other than 'true' turns it off.
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = os.environ.get('DEBUG', 'true') == 'true'