mirror of
https://codeberg.org/Freedium-cfd/web.git
synced 2026-03-11 09:04:37 +00:00
379 lines
13 KiB
Python
379 lines
13 KiB
Python
import orjson as json
|
|
import time
|
|
import sqlite3
|
|
import threading
|
|
from itertools import islice
|
|
from typing import Union, Optional
|
|
from abc import ABC, abstractmethod
|
|
|
|
import psycopg2
|
|
import uuid
|
|
from loguru import logger
|
|
from psycopg2.extras import execute_batch
|
|
|
|
try:
|
|
import sqlite_zstd
|
|
except ImportError:
|
|
sqlite_zstd = None
|
|
|
|
|
|
class CacheData:
    """Wrapper around a cached payload.

    The payload is kept verbatim; ``json()`` decodes it. ``None`` or ``""``
    means "no data" (see ``has_data``).
    """

    __slots__ = ("data",)

    def __init__(self, data: Optional[str]):
        self.data = data

    def json(self):
        """Decode the payload as JSON and return the resulting Python object."""
        return json.loads(self.data)

    def __repr__(self):
        # Conventional debug repr; never fails even when the payload is None.
        return f"{type(self).__name__}({self.data!r})"

    def __str__(self):
        # __str__ must return a str: a None payload renders as "" instead of
        # raising TypeError.
        return "" if self.data is None else self.data

    def has_data(self):
        """Return True when the payload is neither None nor an empty string."""
        return self.data is not None and self.data != ""
|
|
|
|
class CacheResponse:
    """A cache entry: its ``key`` plus the payload wrapped as ``CacheData``."""

    __slots__ = ("key", "data")

    def __init__(self, key: str, data: Union[CacheData, str, bytes]):
        self.key: str = key
        if isinstance(data, (bytes, bytearray)):
            # SQLite returns BLOB values as bytes (e.g. payloads produced by
            # orjson.dumps, which returns bytes). Decode them so ``data`` becomes
            # a CacheData rather than a bare bytes object without ``.json()``.
            data = bytes(data).decode("utf-8")
        self.data: CacheData = CacheData(data) if isinstance(data, str) else data

    def json(self):
        """Decode the wrapped payload as JSON."""
        return self.data.json()
|
|
|
|
|
|
class AbstractCacheBackend(ABC):
    """Interface that every cache storage backend implements.

    Entries are addressed by string keys; ``push`` accepts a str or dict value.
    """

    @abstractmethod
    def init_db(self):
        """Create the backing storage if it does not exist yet."""

    @abstractmethod
    def all(self):
        """Return every stored row."""

    @abstractmethod
    def all_length(self) -> int:
        """Return the number of stored entries."""

    @abstractmethod
    def random(self, size: int) -> list[CacheResponse]:
        """Return up to ``size`` randomly chosen entries."""

    @abstractmethod
    def pull(self, key: str) -> Union[CacheResponse, None]:
        """Return the entry stored under ``key``, or None when absent."""

    @abstractmethod
    def push(self, key: str, value: Union[str, dict]) -> None:
        """Store ``value`` under ``key``."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove the entry stored under ``key``."""

    @abstractmethod
    def close(self):
        """Release the backend's resources."""
|
|
|
|
|
|
class SQLiteCacheBackend(AbstractCacheBackend):
    """Cache backend storing key/value rows in the SQLite table ``cache``.

    With ``zstd_enabled`` the ``sqlite_zstd`` extension is loaded on connect and
    transparent zstd compression of ``cache.value`` is requested.
    """

    __slots__ = ("connection", "cursor", "database", "lock", "zstd_enabled")

    def __init__(self, database: str, zstd_enabled: bool = False):
        self.database = database
        self.connection = None
        self.cursor = None
        # Taken only by ``push``; the other methods do not acquire it.
        self.lock = threading.Lock()
        self.zstd_enabled = zstd_enabled
        self.connect()

    def connect(self):
        """Open the connection, apply PRAGMAs and optionally load ``sqlite_zstd``.

        Raises:
            ValueError: ``zstd_enabled`` is set but ``sqlite_zstd`` could not be imported.
        """
        self.connection = sqlite3.connect(self.database, timeout=10.0)
        self.connection.enable_load_extension(True)
        self.connection.execute("PRAGMA foreign_keys = ON;")
        self.connection.execute("PRAGMA journal_mode=WAL;")
        self.connection.execute("PRAGMA auto_vacuum=full;")
        self.cursor = self.connection.cursor()

        if self.zstd_enabled:
            if sqlite_zstd is None:
                raise ValueError("sqlite_zstd library not found.")

            sqlite_zstd.load(self.connection)
            self.enable_zstd()

    def ensure_connection(self):
        """(Re)connect when no connection or cursor is open."""
        if self.connection is None or self.cursor is None:
            self.connect()

    def all(self):
        """Return every row of ``cache``."""
        self.ensure_connection()
        with self.connection:
            return self.cursor.execute("SELECT * FROM cache").fetchall()

    def all_length(self) -> int:
        """Return the number of rows in ``cache``."""
        self.ensure_connection()
        with self.connection:
            return self.cursor.execute("SELECT COUNT(*) FROM cache").fetchone()[0]

    def random(self, size: int) -> list[CacheResponse]:
        """Return up to ``size`` randomly ordered entries."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT ?", (size,))
            return [CacheResponse(key, value) for key, value in self.cursor]

    def enable_zstd(self):
        """Request transparent zstd compression for ``cache.value``.

        Best effort: any failure is logged (with traceback) and swallowed.
        """
        self.ensure_connection()
        with self.connection:
            try:
                self.cursor.execute('SELECT zstd_enable_transparent(\'{"table": "cache", "column": "value", "compression_level": 9, "dict_chooser": "\'\'a\'\'"}\')')
            except Exception as error:
                # A single logger.exception call records the message and the traceback.
                logger.exception(f"Error enabling ZSTD compression: {error}")

        self.connection.execute("PRAGMA auto_vacuum=full")

    def init_db(self):
        """Create the ``cache`` table (TEXT key primary key, TEXT value) if missing."""
        self.ensure_connection()
        with self.connection:
            # No separate index on ``key``: the PRIMARY KEY already provides one.
            self.cursor.execute("CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, value TEXT)")

    def pull(self, key: str) -> Union[CacheResponse, None]:
        """Return the entry for ``key``, or None when it is absent."""
        self.ensure_connection()
        with self.connection:
            cache = self.cursor.execute("SELECT value FROM cache WHERE key = :0", {"0": key}).fetchone()
            if cache:
                logger.debug("Value found in DB, returning it")
                return CacheResponse(key, cache[0])
            else:
                logger.debug(f"No value found for key: {key}")
                return None

    def push(self, key: str, value: Union[str, dict]) -> None:
        """Insert or replace ``value`` under ``key``; dicts are serialized to JSON text.

        Raises:
            ValueError: ``value`` is neither str nor dict, or cannot be serialized.
        """
        if isinstance(value, dict):
            try:
                value = json.dumps(value)
            except TypeError as e:
                raise ValueError(f"Unable to serialize value to JSON: {e}") from e
            # orjson.dumps returns bytes. Store text so the row is a TEXT value
            # (bytes would be stored as a BLOB and read back as bytes).
            if isinstance(value, bytes):
                value = value.decode("utf-8")
        elif not isinstance(value, str):
            raise ValueError(f"value argument should be a string or dict, not {type(value).__name__}")

        self.ensure_connection()
        with self.lock:
            with self.connection:
                self.cursor.execute("INSERT OR REPLACE INTO cache VALUES (:0, :1)", {"0": key, "1": value})

    def delete(self, key: str) -> None:
        """Delete ``key`` if present; a missing key is only logged at debug level."""
        self.ensure_connection()
        with self.connection:
            # Existence is checked with a SELECT rather than ``cursor.rowcount``.
            # NOTE(review): presumably because ``cache`` may be a view with INSTEAD OF
            # triggers when sqlite_zstd transparent compression is on, and SQLite does
            # not count changes made through such triggers. Confirm before simplifying.
            result = self.cursor.execute("SELECT 1 FROM cache WHERE key = :0", {"0": key}).fetchone()
            if result:
                self.cursor.execute("DELETE FROM cache WHERE key = :0", {"0": key})
                logger.debug(f"Deleted key: {key}")
            else:
                logger.debug(f"Attempted to delete non-existing key: {key}")

    def _generate_test_data(self, num_rows: int, batch_size: int = 10000):
        """Insert exactly ``num_rows`` synthetic rows for load testing.

        Every row reuses the value of one randomly chosen existing row under a
        fresh UUID4 key. Rows are inserted and committed ``batch_size`` at a time.

        Raises:
            ValueError: the table is empty, so there is no value to copy.
        """
        logger.info("Generating test data")
        self.ensure_connection()
        with self.connection:
            # Fetch a random existing value just once.
            row = self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT 1").fetchone()
            if row is None:
                raise ValueError("Cannot generate test data: the cache table is empty")
            value = row[1]

            for offset in range(0, num_rows, batch_size):
                # The final batch may be shorter, so exactly num_rows rows are inserted.
                count = min(batch_size, num_rows - offset)
                batch = [(str(uuid.uuid4()), value) for _ in range(count)]
                self.cursor.executemany("INSERT INTO cache VALUES (?, ?)", batch)
                self.connection.commit()
                print(f"Inserted {offset + count} / {num_rows} rows")

    def maintenance(self, time: Optional[int] = None, blocking_time: float = 0.5):
        """Run zstd incremental maintenance, then ``VACUUM`` and ``ANALYZE``.

        Uses its own short-lived connection, so it can run on the thread started by
        ``maintenance_thread``. That connection is always closed, even on failure.

        Args:
            time: first argument of ``zstd_incremental_maintenance``; ``None`` is
                bound as SQL NULL. The name shadows the ``time`` module inside this
                method but is kept for keyword callers.
            blocking_time: second argument of ``zstd_incremental_maintenance``.
        """
        connection = sqlite3.connect(self.database)
        try:
            connection.execute("PRAGMA foreign_keys = ON;")  # Needed for foreign keys
            connection.execute("PRAGMA journal_mode=WAL;")  # Needed for ZSTD compression to work properly
            connection.execute("PRAGMA auto_vacuum=full;")  # Also for ZSTD compression

            if sqlite_zstd is None:
                # zstd_incremental_maintenance() only exists once the extension is loaded.
                logger.warning("sqlite_zstd is not installed; skipping zstd incremental maintenance")
            else:
                connection.enable_load_extension(True)  # Required to load sqlite_zstd
                sqlite_zstd.load(connection)
                with connection:
                    connection.execute("SELECT zstd_incremental_maintenance(?, ?);", (time, blocking_time))

            # Run outside the ``with`` block: VACUUM cannot run inside an open transaction.
            connection.execute("VACUUM")
            connection.execute("ANALYZE")
        finally:
            connection.close()

    def maintenance_thread(self):
        """Run ``maintenance`` with default arguments on a daemon thread.

        Returns the started thread so callers can ``join()`` it if needed.
        """
        maintenance_thread = threading.Thread(target=self.maintenance, daemon=True)
        maintenance_thread.start()
        return maintenance_thread

    def show_schema_info(self):
        """Return the ``sql`` column of every ``sqlite_master`` entry."""
        self.ensure_connection()
        with self.connection:
            return self.connection.execute("SELECT sql FROM sqlite_master").fetchall()

    def close(self):
        """Close the connection (if open) and forget the connection and cursor."""
        if self.connection:
            self.connection.close()
            self.connection = None
            self.cursor = None

    def __del__(self):
        # getattr: __init__ may not have run to completion (e.g. an instance
        # created via __new__), in which case the slot is unset.
        if getattr(self, "connection", None) is not None:
            self.close()
|
|
|
|
|
|
class PostgreSQLCacheBackend(AbstractCacheBackend):
    """Cache backend storing key/value text rows in a PostgreSQL ``cache`` table via psycopg2."""

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self.connection = None
        self.cursor = None
        self.connect()

    def connect(self):
        """Open a new connection and cursor from ``connection_string``."""
        self.connection = psycopg2.connect(self.connection_string)
        self.cursor = self.connection.cursor()

    def ensure_connection(self):
        """Reconnect when the connection is missing or closed; recreate a missing or closed cursor."""
        if self.connection is None or self.connection.closed:
            self.connect()
        elif self.cursor is None or self.cursor.closed:
            self.cursor = self.connection.cursor()

    def init_db(self):
        """Create the ``cache`` table (TEXT primary key, TEXT value) if missing."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT
                )
                """
            )

    def all(self):
        """Return every row of ``cache``."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("SELECT * FROM cache")
            return self.cursor.fetchall()

    def all_length(self) -> int:
        """Return the number of rows in ``cache``."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("SELECT COUNT(*) FROM cache")
            return self.cursor.fetchone()[0]

    def random(self, size: int) -> list[CacheResponse]:
        """Return up to ``size`` randomly ordered entries."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("SELECT key, value FROM cache ORDER BY RANDOM() LIMIT %s", (size,))
            return [CacheResponse(key, value) for key, value in self.cursor]

    def pull(self, key: str) -> Union[CacheResponse, None]:
        """Return the entry for ``key``, or None when it is absent."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("SELECT value FROM cache WHERE key = %s", (key,))
            cache = self.cursor.fetchone()
            if cache:
                logger.debug("Value found in DB, returning it")
                return CacheResponse(key, cache[0])
            else:
                logger.debug(f"No value found for key: {key}")
                return None

    def push(self, key: str, value: Union[str, dict]) -> None:
        """Upsert ``value`` under ``key``; dicts are serialized to JSON text.

        Raises:
            ValueError: ``value`` is neither str nor dict, or cannot be serialized.
        """
        if isinstance(value, dict):
            try:
                value = json.dumps(value)
            except TypeError as e:
                raise ValueError(f"Unable to serialize value to JSON: {e}") from e
            # orjson.dumps returns bytes, which psycopg2 adapts as ``bytea``.
            # A bytea value cannot be inserted into the TEXT ``value`` column.
            if isinstance(value, bytes):
                value = value.decode("utf-8")
        elif not isinstance(value, str):
            raise ValueError(f"value argument should be a string or dict, not {type(value).__name__}")

        self.ensure_connection()
        with self.connection:
            self.cursor.execute("INSERT INTO cache (key, value) VALUES (%s, %s) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", (key, value))

    def delete(self, key: str) -> None:
        """Delete ``key`` if present; a missing key is only logged at debug level."""
        self.ensure_connection()
        with self.connection:
            self.cursor.execute("DELETE FROM cache WHERE key = %s", (key,))
            if self.cursor.rowcount > 0:
                logger.debug(f"Deleted key: {key}")
            else:
                logger.debug(f"Attempted to delete non-existing key: {key}")

    def close(self):
        """Close the cursor and connection and reset both attributes to None.

        The connection is closed in ``finally``, so a failing ``cursor.close()``
        cannot leak it.
        """
        cursor, connection = self.cursor, self.connection
        self.cursor = None
        self.connection = None
        try:
            if cursor is not None and not cursor.closed:
                cursor.close()
        finally:
            if connection is not None:
                connection.close()
|
|
|
|
|
|
def _iter_sqlite_chunks(sqlite_db, chunk_size: int):
    """Yield lists of ``(key, value)`` rows from SQLite ``cache`` in key order.

    Uses keyset pagination (``WHERE key > last_key ORDER BY key``) instead of
    ``LIMIT ... OFFSET`` without ``ORDER BY``. The latter has no guaranteed
    order between queries, so rows can be skipped or duplicated. It also re-scans
    every skipped row on each page, which is quadratic over the whole table.
    """
    cursor = sqlite_db.cursor
    chunk = cursor.execute("SELECT key, value FROM cache ORDER BY key LIMIT ?", (chunk_size,)).fetchall()
    while chunk:
        yield chunk
        last_key = chunk[-1][0]
        chunk = cursor.execute(
            "SELECT key, value FROM cache WHERE key > ? ORDER BY key LIMIT ?",
            (last_key, chunk_size),
        ).fetchall()


def migrate_to_postgres(sqlite_db_path: str, pg_conn_string: str, chunk_size: int = 1000):
    """Copy every row of the SQLite ``cache`` table into PostgreSQL.

    Rows are upserted (``ON CONFLICT (key) DO UPDATE``) ``chunk_size`` at a time,
    committing after each chunk. Setup errors (opening either database,
    ``init_db``, counting rows) propagate. Errors while copying are logged, the
    open PostgreSQL transaction is rolled back, and the function returns without
    reporting success. Both connections are always closed.

    Args:
        sqlite_db_path: path of the SQLite database (opened with zstd enabled).
        pg_conn_string: psycopg2 connection string of the target database.
        chunk_size: rows read from SQLite and written to PostgreSQL per batch.
    """
    logger.debug(f"Starting migration from SQLite ({sqlite_db_path}) to PostgreSQL")
    sqlite_db = SQLiteCacheBackend(sqlite_db_path, zstd_enabled=True)
    pg_db = None
    try:
        pg_db = PostgreSQLCacheBackend(pg_conn_string)
        pg_db.init_db()
        logger.debug("PostgreSQL database initialized")

        total_rows = sqlite_db.all_length()
        logger.info(f"Total rows to migrate: {total_rows}")
        processed_rows = 0
        start_time = time.time()

        try:
            for chunk in _iter_sqlite_chunks(sqlite_db, chunk_size):
                logger.debug(f"Processing chunk of {len(chunk)} rows")
                execute_batch(pg_db.cursor, "INSERT INTO cache (key, value) VALUES (%s, %s) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", chunk)
                pg_db.connection.commit()

                processed_rows += len(chunk)
                elapsed_time = time.time() - start_time
                # Coarse clocks can report 0 elapsed seconds for a fast first chunk.
                rows_per_second = processed_rows / elapsed_time if elapsed_time > 0 else float("inf")
                logger.info(f"Processed {processed_rows}/{total_rows} rows. Speed: {rows_per_second:.2f} rows/second")
            logger.debug("No more rows to process")
        except Exception as e:
            logger.exception(f"An error occurred during migration: {e}")
            pg_db.connection.rollback()
            # Do not fall through to the success report below.
            return
    finally:
        sqlite_db.close()
        if pg_db is not None:
            pg_db.close()
        logger.debug("Database connections closed")

    total_time = time.time() - start_time
    average_speed = processed_rows / total_time if total_time > 0 else float("inf")
    logger.success("Data migration to PostgreSQL completed")
    logger.info(f"Total time: {total_time:.2f} seconds. Average speed: {average_speed:.2f} rows/second")
|
|
|
|
|
|
def execute_migrate_to_postgres_in_thread(sqlite_db_path: str, pg_conn_string: str, chunk_size: int = 1000):
    """Start ``migrate_to_postgres`` on a daemon thread and return that thread.

    The thread is a daemon, so it does not keep the interpreter alive on its
    own; ``join()`` the returned thread to wait for the migration to finish.
    """
    logger.info("Starting migration to PostgreSQL in thread")
    call_args = (sqlite_db_path, pg_conn_string, chunk_size)
    worker = threading.Thread(
        target=migrate_to_postgres,
        args=call_args,
        daemon=True,
    )
    worker.start()
    return worker
|