diff --git a/database-lib/database_lib/main.py b/database-lib/database_lib/main.py index b8127f7..cb26330 100644 --- a/database-lib/database_lib/main.py +++ b/database-lib/database_lib/main.py @@ -75,31 +75,43 @@ class AbstractCacheBackend(ABC): class SQLiteCacheBackend(AbstractCacheBackend): - __slots__ = ("connection", "cursor", "database", "lock") + __slots__ = ("connection", "cursor", "database", "lock", "zstd_enabled") def __init__(self, database: str, zstd_enabled: bool = False): self.database = database - self.connection = sqlite3.connect(database, timeout=10.0) - self.connection.enable_load_extension(True) # Enable loading of extensions - self.connection.execute("PRAGMA foreign_keys = ON;") # Need for working with foreign keys in db - self.connection.execute("PRAGMA journal_mode=WAL;") # Need to properly work with ZSTD compression - self.connection.execute("PRAGMA auto_vacuum=full;") # Same as above thing - self.cursor = self.connection.cursor() + self.connection = None + self.cursor = None self.lock = threading.Lock() + self.zstd_enabled = zstd_enabled + self.connect() - if zstd_enabled: + def connect(self): + self.connection = sqlite3.connect(self.database, timeout=10.0) + self.connection.enable_load_extension(True) + self.connection.execute("PRAGMA foreign_keys = ON;") + self.connection.execute("PRAGMA journal_mode=WAL;") + self.connection.execute("PRAGMA auto_vacuum=full;") + self.cursor = self.connection.cursor() + + if self.zstd_enabled: sqlite_zstd.load(self.connection) - # self.enable_zstd() + + def ensure_connection(self): + if self.connection is None or self.cursor is None: + self.connect() def all(self): + self.ensure_connection() with self.connection: return self.cursor.execute("SELECT * FROM cache").fetchall() def all_length(self) -> int: + self.ensure_connection() with self.connection: return self.cursor.execute("SELECT COUNT(*) FROM cache").fetchone()[0] def random(self, size: int) -> list[CacheResponse]: + self.ensure_connection() with self.connection: self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT ?", (size,)) return [CacheResponse(key, value) for key, value in self.cursor] @@ -108,6 +120,7 @@ class SQLiteCacheBackend(AbstractCacheBackend): if not self.zstd_enabled: raise ValueError("Can't use zstd compression. Please install 'sqlite_zstd' package") + self.ensure_connection() with self.connection: try: self.cursor.execute('SELECT zstd_enable_transparent(\'{"table": "cache", "column": "value", "compression_level": 9, "dict_chooser": "\'\'a\'\'"}\')') @@ -118,11 +131,13 @@ class SQLiteCacheBackend(AbstractCacheBackend): self.connection.execute("PRAGMA auto_vacuum=full") def init_db(self): + self.ensure_connection() with self.connection: self.cursor.execute("CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, value TEXT)") # self.cursor.execute("CREATE INDEX IF NOT EXISTS idx_key ON cache (key)") def pull(self, key: str) -> Union[CacheResponse, None]: + self.ensure_connection() with self.connection: cache = self.cursor.execute("SELECT value FROM cache WHERE key = :0", {"0": key}).fetchone() if cache: @@ -141,11 +156,13 @@ class SQLiteCacheBackend(AbstractCacheBackend): elif not isinstance(value, str): raise ValueError(f"value argument should be a string or dict, not {type(value).__name__}") + self.ensure_connection() with self.lock: with self.connection: self.cursor.execute("INSERT OR REPLACE INTO cache VALUES (:0, :1)", {"0": key, "1": value}) def delete(self, key: str) -> None: + self.ensure_connection() with self.connection: result = self.cursor.execute("SELECT 1 FROM cache WHERE key = :0", {"0": key}).fetchone() if result: @@ -156,6 +173,7 @@ class SQLiteCacheBackend(AbstractCacheBackend): def _generate_test_data(self, num_rows: int, batch_size: int = 10000): logger.info("Generating test data") + self.ensure_connection() with self.connection: # Fetch a random key-value pair just once self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT 1") @@ -200,22 +218,39 @@ class SQLiteCacheBackend(AbstractCacheBackend): maintenance_thread.start() def show_schema_info(self): + self.ensure_connection() with self.connection: return self.connection.execute("SELECT sql FROM sqlite_master").fetchall() def close(self): - self.__del__() + if self.connection: + self.connection.close() + self.connection = None + self.cursor = None - def __del__(self) -> None: - self.connection.close() + def __del__(self): + self.close() class PostgreSQLCacheBackend(AbstractCacheBackend): def __init__(self, connection_string: str): - self.connection = psycopg2.connect(connection_string) + self.connection_string = connection_string + self.connection = None + self.cursor = None + self.connect() + + def connect(self): + self.connection = psycopg2.connect(self.connection_string) self.cursor = self.connection.cursor() + def ensure_connection(self): + if self.connection is None or self.connection.closed: + self.connect() + elif self.cursor is None or self.cursor.closed: + self.cursor = self.connection.cursor() + def init_db(self): + self.ensure_connection() with self.connection: self.cursor.execute( """ @@ -227,21 +262,25 @@ class PostgreSQLCacheBackend(AbstractCacheBackend): ) def all(self): + self.ensure_connection() with self.connection: self.cursor.execute("SELECT * FROM cache") return self.cursor.fetchall() def all_length(self) -> int: + self.ensure_connection() with self.connection: self.cursor.execute("SELECT COUNT(*) FROM cache") return self.cursor.fetchone()[0] def random(self, size: int) -> list[CacheResponse]: + self.ensure_connection() with self.connection: self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT %s", (size,)) return [CacheResponse(key, value) for key, value in self.cursor] def pull(self, key: str) -> Union[CacheResponse, None]: + self.ensure_connection() with self.connection: self.cursor.execute("SELECT value FROM cache WHERE key = %s", (key,)) cache = self.cursor.fetchone() @@ -261,10 +300,12 @@ class PostgreSQLCacheBackend(AbstractCacheBackend): elif not isinstance(value, str): raise ValueError(f"value argument should be a string or dict, not {type(value).__name__}") + self.ensure_connection() with self.connection: self.cursor.execute("INSERT INTO cache (key, value) VALUES (%s, %s) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", (key, value)) def delete(self, key: str) -> None: + self.ensure_connection() with self.connection: self.cursor.execute("DELETE FROM cache WHERE key = %s", (key,)) if self.cursor.rowcount > 0: @@ -273,8 +314,12 @@ class PostgreSQLCacheBackend(AbstractCacheBackend): logger.debug(f"Attempted to delete non-existing key: {key}") def close(self): - self.cursor.close() - self.connection.close() + if self.cursor: + self.cursor.close() + if self.connection: + self.connection.close() + self.cursor = None + self.connection = None def migrate_to_postgres(sqlite_db_path: str, pg_conn_string: str, chunk_size: int = 1000):