mirror of
https://github.com/maciej3031/comixify.git
synced 2026-03-11 08:54:35 +00:00
17 lines
504 B
Python
17 lines
504 B
Python
|
|
import pickle
|
||
|
|
import os.path
|
||
|
|
|
||
|
|
# Location of the pickled model that PopularityPredictor loads. The path is
# relative, so it is resolved against the process's current working directory.
MODEL_PATH = 'popularity/pretrained_model/svr_test_11.10.sk'
|
||
|
|
|
||
|
|
|
||
|
|
class PopularityPredictor:
    """Wrap a pickled regression model and score single image feature vectors.

    The unpickled object is stored on ``self.svr`` and only needs to expose a
    ``predict`` method.
    NOTE(review): the attribute name and the ``.sk`` file extension suggest a
    scikit-learn SVR -- confirm against the training code.
    """

    def __init__(self, model_path=None):
        """Load the pickled model from disk.

        Args:
            model_path: Optional path to the pickled model file. When omitted
                (or ``None``) the module-level ``MODEL_PATH`` is used.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if model_path is None:
            model_path = MODEL_PATH

        # Fail immediately with an explicit message instead of printing a
        # warning and then crashing on the open() call below anyway.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                "Model file does not exist: {}".format(model_path)
            )

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point this at model files from a trusted source.
        with open(model_path, 'rb') as fp:
            # encoding='latin1' lets Python 3 read pickles written by
            # Python 2 that contain NumPy arrays (presumably how this model
            # was saved -- TODO confirm).
            self.svr = pickle.load(fp, encoding='latin1')

    def get_popularity_score(self, image_feature):
        """Return the model's prediction for one image feature vector.

        Args:
            image_feature: Array-like object with a ``reshape`` method
                (presumably a NumPy array holding the features of a single
                image). It is reshaped to one row before prediction.

        Returns:
            Whatever ``self.svr.predict`` returns for the single-row input.
        """
        # Reshape to (1, n_features): one sample per row.
        single_sample = image_feature.reshape(1, -1)
        return self.svr.predict(single_sample)
|