Fix tf session problem

This commit is contained in:
Maciej Pęśko 2018-11-10 13:01:35 +01:00
parent fb24d4b156
commit b2d059bc4d
5 changed files with 34 additions and 27 deletions

View file

@ -17,6 +17,9 @@ class ComixGAN:
config.gpu_options.allow_growth = True
self.session = tf.Session(graph=self.graph, config=config)
with self.graph.as_default():
with tf.device('/device:GPU:0'):
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
custom_objects={'InstanceNormalization': InstanceNormalization})
self.session = tf.Session()
with self.session.as_default():
with tf.device('/device:GPU:0'):
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
custom_objects={'InstanceNormalization': InstanceNormalization})

View file

@ -18,7 +18,7 @@ import logging
from utils import jj, profile
from keyframes_rl.models import DSN
from popularity.models import PopularityPredictor
from neural_image_assessment.models import NeuralImageAssessment
from neural_image_assessment.model import NeuralImageAssessment
from keyframes.kts import cpd_auto
from keyframes.utils import batch

View file

@ -1,27 +1,28 @@
import os
import errno
import os
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array
from keras.applications.nasnet import preprocess_input
import tensorflow as tf
from PIL import Image
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
from keras.applications.nasnet import preprocess_input
from keras.models import load_model
from keras.preprocessing.image import img_to_array
from django.conf import settings
class NeuralImageAssessment:
def __init__(self):
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)
if not os.path.exists(settings.NIMA_MODEL_PATH):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
self.graph = tf.Graph()
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1
config.gpu_options.per_process_gpu_memory_fraction = 0.2
config.gpu_options.allow_growth = True
self.session = tf.Session(graph=self.graph, config=config)
with self.graph.as_default():
self.model = load_model(MODEL_PATH)
self.session = tf.Session()
with self.session.as_default():
self.model = load_model(settings.NIMA_MODEL_PATH)
@staticmethod
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
@ -46,13 +47,14 @@ class NeuralImageAssessment:
def get_assessment_score(self, img_array):
with self.graph.as_default():
target_size = (224, 224)
img = NeuralImageAssessment.resize_image(img_array, target_size)
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
with self.session.as_default():
target_size = (224, 224)
img = NeuralImageAssessment.resize_image(img_array, target_size)
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
mean = NeuralImageAssessment.mean_score(scores)
return mean

View file

@ -148,3 +148,4 @@ DEFAULT_IMAGE_ASSESSMENT_MODE = 0
DEFAULT_STYLE_TRANSFER_MODE = 0
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')

View file

@ -52,12 +52,13 @@ class StyleTransfer():
frames = cls._resize_images(frames, size=450)
with comixGAN.graph.as_default():
batch_size = 1
stylized_imgs = []
for i in range(0, len(frames), batch_size):
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 2))
with comixGAN.session.as_default():
batch_size = 1
stylized_imgs = []
for i in range(0, len(frames), batch_size):
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 2))
# K.clear_session()
# gc.collect()
return list(np.concatenate(stylized_imgs, axis=0))