Add profiling (#14)

This commit is contained in:
Adam Svystun 2018-11-09 12:19:31 +01:00 committed by GitHub
parent d453748661
commit 244a7deb53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 9 deletions

View file

@ -35,7 +35,7 @@ class Video(models.Model):
raise TooLargeFile()
def create_comic(self, frames_mode=0, rl_mode=0, image_assessment_mode=0):
keyframes, keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
(keyframes, keyframes_timings), keyframes_extraction_time = KeyFramesExtractor.get_keyframes(
video=self,
frames_mode=frames_mode,
rl_mode=rl_mode,
@ -44,9 +44,12 @@ class Video(models.Model):
stylized_keyframes, stylization_time = StyleTransfer.get_stylized_frames(frames=keyframes)
comic_image, layout_generation_time = LayoutGenerator.get_layout(frames=stylized_keyframes)
timings = {'keyframes_extraction_time': keyframes_extraction_time,
'stylization_time': stylization_time,
'layout_generation_time': layout_generation_time}
timings = {
'keyframes_extraction_time': keyframes_extraction_time,
'stylization_time': stylization_time,
'layout_generation_time': layout_generation_time,
**keyframes_timings
}
return comic_image, timings

View file

@ -32,18 +32,24 @@ class KeyFramesExtractor:
@profile
def get_keyframes(cls, video, gpu=settings.GPU, features_batch_size=settings.FEATURE_BATCH_SIZE,
frames_mode=0, rl_mode=0, image_assessment_mode=0):
frames_paths, all_frames_tmp_dir = cls._get_all_frames(video, mode=frames_mode)
timings = {}
(frames_paths, all_frames_tmp_dir), get_frames_timing = cls._get_all_frames(video, mode=frames_mode)
timings["get_frames_time"] = get_frames_timing
frames = cls._get_frames(frames_paths)
features = cls._get_features(frames, gpu, features_batch_size)
features, get_features_time = cls._get_features(frames, gpu, features_batch_size)
timings["get_features_time"] = get_features_time
norm_features = normalize(features)
change_points, frames_per_segment = cls._get_segments(norm_features)
probs = cls._get_probs(norm_features, gpu, mode=rl_mode)
probs, highlightness_time = cls._get_probs(norm_features, gpu, mode=rl_mode)
timings["highlightness_time"] = highlightness_time
keyframes = cls._get_keyframes(frames, probs, change_points, frames_per_segment)
chosen_frames = cls._get_popularity_chosen_frames(keyframes, features, image_assessment_mode)
chosen_frames, second_filtering_time = cls._get_popularity_chosen_frames(keyframes, features, image_assessment_mode)
timings["second_filtering_time"] = second_filtering_time
shutil.rmtree(jj(f"{settings.TMP_DIR}", f"{all_frames_tmp_dir}"))
return chosen_frames
return chosen_frames, timings
@staticmethod
@profile
def _get_all_frames(video, mode=0):
all_frames_tmp_dir = uuid.uuid4().hex
os.mkdir(jj(settings.TMP_DIR, all_frames_tmp_dir))
@ -72,6 +78,7 @@ class KeyFramesExtractor:
return frames
@staticmethod
@profile
def _get_features(frames, gpu=True, batch_size=1):
caffe_root = os.environ.get("CAFFE_ROOT")
if not caffe_root:
@ -108,6 +115,7 @@ class KeyFramesExtractor:
return features.astype(np.float32)
@staticmethod
@profile
def _get_probs(features, gpu=True, mode=0):
model_cache_key = "keyframes_rl_model_cache_" + str(mode)
model = cache.get(model_cache_key) # get model from cache
@ -159,6 +167,7 @@ class KeyFramesExtractor:
return chosen_frames
@staticmethod
@profile
def _get_popularity_chosen_frames(frames, features, image_assessment_mode=0, n_frames=10):
if image_assessment_mode == 1:
model_cache_key = "popularity_model_cache"