123 lines
3.4 KiB
Python
123 lines
3.4 KiB
Python
import logging
|
|
import sqlite3
|
|
|
|
import aiosqlite
|
|
|
|
from .migrator import dumb_migrate_db
|
|
from .utils import contains
|
|
|
|
# Logger for this module. Its level is pinned to INFO here, so the log.debug(...)
# calls in this module are dropped unless the level is lowered elsewhere.
log = logging.getLogger(name="nomen.db")
log.setLevel(level=logging.INFO)
|
|
|
|
# Canonical database schema; run_db_migrations() hands this text to dumb_migrate_db().
# NOTE(review): whether dumb_migrate_db() compares this text verbatim or after
# normalising whitespace isn't visible from here -- if verbatim, purely cosmetic
# edits to this string may trigger a migration.
schema = """
PRAGMA user_version = 1;

CREATE TABLE keywords (
    guild_id INTEGER NOT NULL,
    keyword TEXT NOT NULL,
    user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
    regex INTEGER NOT NULL DEFAULT 0 CHECK(regex IN (0, 1)),
    count INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE guilds (
    guild_id INTEGER NOT NULL PRIMARY KEY,
    prefix TEXT NOT NULL DEFAULT ">"
)
WITHOUT ROWID;

CREATE TABLE users (
    user_id INTEGER NOT NULL PRIMARY KEY,
    disabled INTEGER NOT NULL DEFAULT 0 CHECK(disabled IN (0, 1)),
    use_embed INTEGER NOT NULL DEFAULT 1 CHECK(use_embed IN (0, 1)),
    notify_self INTEGER NOT NULL DEFAULT 0 CHECK(notify_self IN (0, 1)),
    bots_notify INTEGER NOT NULL DEFAULT 0 CHECK(bots_notify IN (0, 1)),
    ignore_active INTEGER NOT NULL DEFAULT 0 CHECK(ignore_active IN (0, 1))
)
WITHOUT ROWID;

CREATE TABLE user_ignores (
    user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
    guild_id INTEGER NOT NULL,
    target INTEGER NOT NULL, -- channel or user id
    PRIMARY KEY (user_id, guild_id, target)
);

CREATE TABLE user_blocks (
    user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
    target INTEGER NOT NULL,
    PRIMARY KEY (user_id, target)
);

CREATE TABLE user_pauses (
    user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE,
    guild_id INTEGER NOT NULL,
    PRIMARY KEY (user_id, guild_id)
);

CREATE INDEX keywords_index ON keywords(user_id);
CREATE INDEX user_ignores_index ON user_ignores(user_id);
CREATE INDEX user_blocks_index ON user_blocks(user_id);
CREATE INDEX user_pauses_index ON user_pauses(user_id);
"""
|
|
|
|
|
|
def run_db_migrations(db_file):
    """
    Synchronously migrate the database at *db_file* towards the module's ``schema``.

    Opens a short-lived :mod:`sqlite3` connection and passes it, together with
    ``schema``, to ``dumb_migrate_db``. Logs ``"Migrated <db_file>"`` at INFO
    level when that helper returns a truthy value.

    A ``sqlite3.Connection`` used as a context manager only commits (or rolls
    back on error); it does not close the connection. The connection is
    therefore closed explicitly in ``finally`` so it is not leaked.
    """
    log.debug("Running automatic migration on %s", db_file)
    db = sqlite3.connect(db_file)
    try:
        # Commit on success, roll back if the migration raises.
        with db:
            if dumb_migrate_db(db, schema):
                log.info("Migrated %s", db_file)
    finally:
        db.close()
    log.debug("Finished running automatic migration")
|
|
|
|
|
|
async def fetch_singleton(db, sql, params=None):
    """
    Run *sql* and return the single value it produces.

    The query is expected to yield one row with one column. ``" LIMIT 1"`` is
    appended to *sql* verbatim, so a statement that already has its own LIMIT
    clause or ends with ``;`` will not execute. An empty result raises
    ``IndexError``.
    """
    query = "{} LIMIT 1".format(sql)
    rows = await db.execute_fetchall(query, params)
    first_row = rows[0]
    return first_row[0]
|
|
|
|
|
|
async def fetch_exists(db, sql, params=None):
    """
    Report whether *sql* returns at least one row.

    The query is wrapped as ``SELECT EXISTS(<sql>)`` and the scalar is read via
    ``fetch_singleton``. The result is SQLite's integer ``1`` or ``0``, not a bool.
    """
    exists_query = "SELECT EXISTS({})".format(sql)
    return await fetch_singleton(db, exists_query, params)
|
|
|
|
|
|
async def fetch_unpacked(db, sql, params=None):
    """
    Run *sql* and return a list holding the first column of every result row.

    Intended for single-column queries where callers want plain values rather
    than row objects. The row factory is overridden on a dedicated cursor only,
    so the connection's own ``row_factory`` is left untouched.

    The cursor is opened with ``async with`` so it is closed as soon as the rows
    have been fetched, instead of lingering until garbage collection.
    """
    async with db.cursor() as cur:
        # Unwrap each raw row tuple to its first column.
        cur.row_factory = lambda cursor, row: row[0]
        await cur.execute(sql, params)
        return await cur.fetchall()
|
|
|
|
|
|
log.debug("Monkeypatching in helpers")
# Attach the query helpers to every aiosqlite connection so callers can write
# e.g. ``await db.fetch_exists(...)``. Each helper takes the connection as its
# first argument, so assigning it on the class makes it bind like a method.
for _helper in (fetch_singleton, fetch_exists, fetch_unpacked):
    setattr(aiosqlite.Connection, _helper.__name__, _helper)
del _helper  # don't leave the loop variable behind as a module global
|
|
|
|
|
|
class Row(sqlite3.Row):
    """A :class:`sqlite3.Row` whose ``repr`` shows its column/value pairs."""

    def __repr__(self):
        # sqlite3.Row provides keys() and lookup by column name, so dict() treats it as a mapping.
        mapping = dict(self)
        return "Row<%r>" % (mapping,)
|
|
|
|
|
|
async def setup_db(db_file):
    """
    Open an aiosqlite connection to *db_file* and configure it for the bot.

    Setup steps:
      * runs the per-connection PRAGMAs (``optimize`` with mask ``0x10002``,
        ``synchronous = NORMAL``, ``foreign_keys = ON``);
      * installs :class:`Row` as the connection's row factory;
      * registers the 3-argument SQL function ``contains`` (implemented by
        ``.utils.contains``) and marks it deterministic.

    If any step fails, or the task is cancelled mid-setup, the connection is
    closed before the exception propagates, so a half-configured connection is
    never leaked. On success the configured connection is returned, and the
    caller is responsible for closing it.
    """
    log.debug("Connecting to %s", db_file)
    db = await aiosqlite.connect(db_file)
    try:
        log.debug("Running start script")
        # foreign_keys is a per-connection setting in SQLite. It is needed for the
        # schema's ON DELETE CASCADE clauses to be enforced.
        await db.executescript(
            """
            PRAGMA optimize(0x10002);
            PRAGMA main.synchronous = NORMAL;
            PRAGMA foreign_keys = ON;
            """
        )

        log.debug("Setting row factory")
        db.row_factory = Row

        log.debug("Adding contains function")
        await db.create_function("contains", 3, contains, deterministic=True)
    except BaseException:
        # BaseException also covers asyncio cancellation; close, then re-raise unchanged.
        await db.close()
        raise

    log.debug("Done setting up DB")
    return db
|