mirror of
https://github.com/maciej3031/comixify.git
synced 2026-03-11 08:54:35 +00:00
Fix tf session problem
This commit is contained in:
parent
fb24d4b156
commit
b2d059bc4d
5 changed files with 34 additions and 27 deletions
|
|
@ -17,6 +17,9 @@ class ComixGAN:
|
|||
config.gpu_options.allow_growth = True
|
||||
self.session = tf.Session(graph=self.graph, config=config)
|
||||
with self.graph.as_default():
|
||||
with tf.device('/device:GPU:0'):
|
||||
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
|
||||
custom_objects={'InstanceNormalization': InstanceNormalization})
|
||||
self.session = tf.Session()
|
||||
with self.session.as_default():
|
||||
with tf.device('/device:GPU:0'):
|
||||
self.model = load_model(settings.COMIX_GAN_MODEL_PATH,
|
||||
custom_objects={'InstanceNormalization': InstanceNormalization})
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import logging
|
|||
from utils import jj, profile
|
||||
from keyframes_rl.models import DSN
|
||||
from popularity.models import PopularityPredictor
|
||||
from neural_image_assessment.models import NeuralImageAssessment
|
||||
from neural_image_assessment.model import NeuralImageAssessment
|
||||
from keyframes.kts import cpd_auto
|
||||
from keyframes.utils import batch
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,28 @@
|
|||
import os
|
||||
import errno
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from keras.models import load_model
|
||||
from keras.preprocessing.image import img_to_array
|
||||
from keras.applications.nasnet import preprocess_input
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
|
||||
|
||||
MODEL_PATH = 'neural_image_assessment/pretrained_model/nima_model.h5'
|
||||
from keras.applications.nasnet import preprocess_input
|
||||
from keras.models import load_model
|
||||
from keras.preprocessing.image import img_to_array
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class NeuralImageAssessment:
|
||||
def __init__(self):
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), MODEL_PATH)
|
||||
if not os.path.exists(settings.NIMA_MODEL_PATH):
|
||||
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), settings.NIMA_MODEL_PATH)
|
||||
self.graph = tf.Graph()
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.1
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.2
|
||||
config.gpu_options.allow_growth = True
|
||||
self.session = tf.Session(graph=self.graph, config=config)
|
||||
with self.graph.as_default():
|
||||
self.model = load_model(MODEL_PATH)
|
||||
self.session = tf.Session()
|
||||
with self.session.as_default():
|
||||
self.model = load_model(settings.NIMA_MODEL_PATH)
|
||||
|
||||
@staticmethod
|
||||
def resize_image(bgr_img_array, target_size=(224, 224), interpolation='nearest'):
|
||||
|
|
@ -46,13 +47,14 @@ class NeuralImageAssessment:
|
|||
|
||||
def get_assessment_score(self, img_array):
|
||||
with self.graph.as_default():
|
||||
target_size = (224, 224)
|
||||
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
||||
x = img_to_array(img)
|
||||
x = np.expand_dims(x, axis=0)
|
||||
x = preprocess_input(x)
|
||||
with self.session.as_default():
|
||||
target_size = (224, 224)
|
||||
img = NeuralImageAssessment.resize_image(img_array, target_size)
|
||||
x = img_to_array(img)
|
||||
x = np.expand_dims(x, axis=0)
|
||||
x = preprocess_input(x)
|
||||
|
||||
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
||||
scores = self.model.predict(x, batch_size=1, verbose=0)[0]
|
||||
mean = NeuralImageAssessment.mean_score(scores)
|
||||
|
||||
return mean
|
||||
|
|
@ -148,3 +148,4 @@ DEFAULT_IMAGE_ASSESSMENT_MODE = 0
|
|||
|
||||
DEFAULT_STYLE_TRANSFER_MODE = 0
|
||||
COMIX_GAN_MODEL_PATH = os.path.join(BASE_DIR, 'ComixGAN', 'pretrained_models', 'generator_model.h5')
|
||||
NIMA_MODEL_PATH = os.path.join(BASE_DIR, 'neural_image_assessment', 'pretrained_model', 'nima_model.h5')
|
||||
|
|
@ -52,12 +52,13 @@ class StyleTransfer():
|
|||
frames = cls._resize_images(frames, size=450)
|
||||
|
||||
with comixGAN.graph.as_default():
|
||||
batch_size = 1
|
||||
stylized_imgs = []
|
||||
for i in range(0, len(frames), batch_size):
|
||||
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
|
||||
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
|
||||
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 2))
|
||||
with comixGAN.session.as_default():
|
||||
batch_size = 1
|
||||
stylized_imgs = []
|
||||
for i in range(0, len(frames), batch_size):
|
||||
batch_of_frames = ((np.stack(frames[i:i + batch_size]) / 255) * 2) - 1
|
||||
stylized_batch_of_imgs = comixGAN.model.predict(batch_of_frames)
|
||||
stylized_imgs.append(255 * ((stylized_batch_of_imgs + 1) / 2))
|
||||
# K.clear_session()
|
||||
# gc.collect()
|
||||
return list(np.concatenate(stylized_imgs, axis=0))
|
||||
|
|
|
|||
Loading…
Reference in a new issue