Compare commits

..

8 commits

4 changed files with 95 additions and 85 deletions

View file

@ -31,7 +31,7 @@ WITHOUT ROWID;
CREATE TABLE users ( CREATE TABLE users (
user_id INTEGER NOT NULL PRIMARY KEY, user_id INTEGER NOT NULL PRIMARY KEY,
disabled INTEGER NOT NULL DEFAULT 0 CHECK(disabled IN (0, 1)), disabled INTEGER NOT NULL DEFAULT 0 CHECK(disabled IN (0, 1)),
use_embeds INTEGER NOT NULL DEFAULT 1 CHECK(use_embeds IN (0, 1)), use_embed INTEGER NOT NULL DEFAULT 1 CHECK(use_embed IN (0, 1)),
notify_self INTEGER NOT NULL DEFAULT 0 CHECK(notify_self IN (0, 1)), notify_self INTEGER NOT NULL DEFAULT 0 CHECK(notify_self IN (0, 1)),
bots_notify INTEGER NOT NULL DEFAULT 0 CHECK(bots_notify IN (0, 1)) bots_notify INTEGER NOT NULL DEFAULT 0 CHECK(bots_notify IN (0, 1))
) )

View file

@ -1,7 +1,6 @@
import io import io
import logging import logging
import os import os
import pprint
import sys import sys
import textwrap import textwrap
import traceback import traceback
@ -27,8 +26,7 @@ log.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(name)s: %(message)s") formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(name)s: %(message)s")
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger_disnake.addHandler(handler) logging.getLogger(None).addHandler(handler)
log.addHandler(handler)
if load_dotenv(find_dotenv(usecwd=True)): if load_dotenv(find_dotenv(usecwd=True)):
log.debug("Loaded .env") log.debug("Loaded .env")
@ -62,7 +60,7 @@ class Nomen(Bot):
intents=options.get("intents"), intents=options.get("intents"),
) )
self.db = self.loop.run_until_complete(setup_db(DB_FILE)) self.db = None # Setup in start
self.prefixes = {} self.prefixes = {}
async def get_guild_prefix(self, guild: Guild): async def get_guild_prefix(self, guild: Guild):
@ -86,6 +84,10 @@ class Nomen(Bot):
) )
self.prefixes[guild.id] = prefix self.prefixes[guild.id] = prefix
async def start(self, *args, **kwargs):
self.db = await setup_db(DB_FILE)
await super().start(*args, **kwargs)
async def close(self): async def close(self):
await super().close() await super().close()
await self.db.close() await self.db.close()
@ -219,6 +221,10 @@ async def prefix(ctx, prefix=None):
def run(): def run():
if run_db_migrations(DB_FILE): try:
log.info(f"Migrated DB {DB_FILE}") if run_db_migrations(DB_FILE):
bot.run(TOKEN) log.info(f"Migrated DB {DB_FILE}")
except RuntimeError:
pass
else:
bot.run(TOKEN)

View file

@ -7,6 +7,8 @@ See <https://david.rothlis.net/declarative-schema-migration-for-sqlite>.
Author: William Manley <will@stb-tester.com>. Author: William Manley <will@stb-tester.com>.
Copyright © 2019-2022 Stb-tester.com Ltd. Copyright © 2019-2022 Stb-tester.com Ltd.
License: MIT. License: MIT.
Modified to ignore internal tables
""" """
import logging import logging
@ -14,6 +16,12 @@ import re
import sqlite3 import sqlite3
from textwrap import dedent from textwrap import dedent
INTERNAL_TABLES = {
"sqlite_stat1",
}
log = logging.getLogger("migrator")
def dumb_migrate_db(db, schema, allow_deletions=False): def dumb_migrate_db(db, schema, allow_deletions=False):
""" """
@ -77,21 +85,19 @@ class DBMigrator:
msg_argv += (args,) msg_argv += (args,)
else: else:
args = [] args = []
logging.info(msg_tmpl, *msg_argv) log.info(msg_tmpl, *msg_argv)
self.db.execute(sql, args) self.db.execute(sql, args)
self.n_changes += 1 self.n_changes += 1
def __enter__(self): def __enter__(self):
self.orig_foreign_keys = ( self.orig_foreign_keys = self.db.execute("PRAGMA foreign_keys").fetchone()[0]
self.db.execute("PRAGMA foreign_keys").fetchone()[0])
if self.orig_foreign_keys: if self.orig_foreign_keys:
self.log_execute("Disable foreign keys temporarily for migration", self.log_execute("Disable foreign keys temporarily for migration", "PRAGMA foreign_keys = OFF")
"PRAGMA foreign_keys = OFF")
# This doesn't count as a change because we'll undo it at the end # This doesn't count as a change because we'll undo it at the end
self.n_changes = 0 self.n_changes = 0
self.db.__enter__() self.db.__enter__()
self.db.execute('BEGIN') self.db.execute("BEGIN")
return self return self
def __exit__(self, exc_type, exc_value, exc_tb): def __exit__(self, exc_type, exc_value, exc_tb):
@ -103,7 +109,7 @@ class DBMigrator:
# > constraint enforcement may only be enabled or disabled when # > constraint enforcement may only be enabled or disabled when
# > there is no pending BEGIN or SAVEPOINT. # > there is no pending BEGIN or SAVEPOINT.
old_changes = self.n_changes old_changes = self.n_changes
new_val = self._migrate_pragma('foreign_keys') new_val = self._migrate_pragma("foreign_keys")
if new_val == self.orig_foreign_keys: if new_val == self.orig_foreign_keys:
self.n_changes = old_changes self.n_changes = old_changes
@ -115,142 +121,133 @@ class DBMigrator:
self.db.execute("VACUUM") self.db.execute("VACUUM")
else: else:
if self.orig_foreign_keys: if self.orig_foreign_keys:
self.log_execute( self.log_execute("Re-enable foreign keys after migration", "PRAGMA foreign_keys = ON")
"Re-enable foreign keys after migration",
"PRAGMA foreign_keys = ON")
def migrate(self): def migrate(self):
# In CI the database schema may be changing all the time. This checks # In CI the database schema may be changing all the time. This checks
# the current db and if it doesn't match database.sql we will # the current db and if it doesn't match database.sql we will
# modify it so it does match where possible. # modify it so it does match where possible.
pristine_tables = dict(self.pristine.execute("""\ pristine_tables = dict(
self.pristine.execute("""\
SELECT name, sql FROM sqlite_master SELECT name, sql FROM sqlite_master
WHERE type = \"table\" AND name != \"sqlite_sequence\"""").fetchall()) WHERE type = \"table\" AND name != \"sqlite_sequence\"""").fetchall()
pristine_indices = dict(self.pristine.execute("""\ )
pristine_indices = dict(
self.pristine.execute("""\
SELECT name, sql FROM sqlite_master SELECT name, sql FROM sqlite_master
WHERE type = \"index\"""").fetchall()) WHERE type = \"index\"""").fetchall()
)
tables = dict(self.db.execute("""\ tables = dict(
self.db.execute("""\
SELECT name, sql FROM sqlite_master SELECT name, sql FROM sqlite_master
WHERE type = \"table\" AND name != \"sqlite_sequence\"""").fetchall()) WHERE type = \"table\" AND name != \"sqlite_sequence\"""").fetchall()
)
new_tables = set(pristine_tables.keys()) - set(tables.keys()) new_tables = set(pristine_tables.keys()) - set(tables.keys()) - INTERNAL_TABLES
removed_tables = set(tables.keys()) - set(pristine_tables.keys()) removed_tables = set(tables.keys()) - set(pristine_tables.keys()) - INTERNAL_TABLES
if removed_tables and not self.allow_deletions: if removed_tables and not self.allow_deletions:
raise RuntimeError( raise RuntimeError("Database migration: Refusing to delete tables %r" % removed_tables)
"Database migration: Refusing to delete tables %r" %
removed_tables)
modified_tables = set( modified_tables = set(
name for name, sql in pristine_tables.items() name for name, sql in pristine_tables.items() if normalise_sql(tables.get(name, "")) != normalise_sql(sql)
if normalise_sql(tables.get(name, "")) != normalise_sql(sql)) )
# This PRAGMA is automatically disabled when the db is committed # This PRAGMA is automatically disabled when the db is committed
self.db.execute("PRAGMA defer_foreign_keys = TRUE") self.db.execute("PRAGMA defer_foreign_keys = TRUE")
# New and removed tables are easy: # New and removed tables are easy:
for tbl_name in new_tables: for tbl_name in new_tables:
self.log_execute("Create table %s" % tbl_name, self.log_execute("Create table %s" % tbl_name, pristine_tables[tbl_name])
pristine_tables[tbl_name])
for tbl_name in removed_tables: for tbl_name in removed_tables:
self.log_execute("Drop table %s" % tbl_name, self.log_execute("Drop table %s" % tbl_name, "DROP TABLE %s" % tbl_name)
"DROP TABLE %s" % tbl_name)
for tbl_name in modified_tables: for tbl_name in modified_tables:
# The SQLite documentation insists that we create the new table and # The SQLite documentation insists that we create the new table and
# rename it over the old rather than moving the old out of the way # rename it over the old rather than moving the old out of the way
# and then creating the new # and then creating the new
create_table_sql = pristine_tables[tbl_name] create_table_sql = pristine_tables[tbl_name]
create_table_sql = re.sub(r"\b%s\b" % re.escape(tbl_name), create_table_sql = re.sub(r"\b%s\b" % re.escape(tbl_name), tbl_name + "_migration_new", create_table_sql)
tbl_name + "_migration_new", self.log_execute("Columns change: Create table %s with updated schema" % tbl_name, create_table_sql)
create_table_sql)
self.log_execute(
"Columns change: Create table %s with updated schema" %
tbl_name, create_table_sql)
cols = set([ cols = set([x[1] for x in self.db.execute("PRAGMA table_info(%s)" % tbl_name)])
x[1] for x in self.db.execute( pristine_cols = set([x[1] for x in self.pristine.execute("PRAGMA table_info(%s)" % tbl_name)])
"PRAGMA table_info(%s)" % tbl_name)])
pristine_cols = set([
x[1] for x in
self.pristine.execute("PRAGMA table_info(%s)" % tbl_name)])
removed_columns = cols - pristine_cols removed_columns = cols - pristine_cols
if not self.allow_deletions and removed_columns: if not self.allow_deletions and removed_columns:
logging.warning( log.warning(
"Database migration: Refusing to remove columns %r from " "Database migration: Refusing to remove columns %r from "
"table %s. Current cols are %r attempting migration to %r", "table %s. Current cols are %r attempting migration to %r",
removed_columns, tbl_name, cols, pristine_cols) removed_columns,
tbl_name,
cols,
pristine_cols,
)
raise RuntimeError( raise RuntimeError(
"Database migration: Refusing to remove columns %r from " "Database migration: Refusing to remove columns %r from " "table %s" % (removed_columns, tbl_name)
"table %s" % (removed_columns, tbl_name)) )
logging.info("cols: %s, pristine_cols: %s", cols, pristine_cols) log.info("cols: %s, pristine_cols: %s", cols, pristine_cols)
self.log_execute( self.log_execute(
"Migrate data for table %s" % tbl_name, """\ "Migrate data for table %s" % tbl_name,
"""\
INSERT INTO {tbl_name}_migration_new ({common}) INSERT INTO {tbl_name}_migration_new ({common})
SELECT {common} FROM {tbl_name}""".format( SELECT {common} FROM {tbl_name}""".format(
tbl_name=tbl_name, tbl_name=tbl_name, common=", ".join(cols.intersection(pristine_cols))
common=", ".join(cols.intersection(pristine_cols)))) ),
)
# Don't need the old table any more # Don't need the old table any more
self.log_execute( self.log_execute("Drop old table %s now data has been migrated" % tbl_name, "DROP TABLE %s" % tbl_name)
"Drop old table %s now data has been migrated" % tbl_name,
"DROP TABLE %s" % tbl_name)
self.log_execute( self.log_execute(
"Columns change: Move new table %s over old" % tbl_name, "Columns change: Move new table %s over old" % tbl_name,
"ALTER TABLE %s_migration_new RENAME TO %s" % ( "ALTER TABLE %s_migration_new RENAME TO %s" % (tbl_name, tbl_name),
tbl_name, tbl_name)) )
# Migrate the indices # Migrate the indices
indices = dict(self.db.execute("""\ indices = dict(
self.db.execute("""\
SELECT name, sql FROM sqlite_master SELECT name, sql FROM sqlite_master
WHERE type = \"index\"""").fetchall()) WHERE type = \"index\"""").fetchall()
)
for name in set(indices.keys()) - set(pristine_indices.keys()): for name in set(indices.keys()) - set(pristine_indices.keys()):
self.log_execute("Dropping obsolete index %s" % name, self.log_execute("Dropping obsolete index %s" % name, "DROP INDEX %s" % name)
"DROP INDEX %s" % name)
for name, sql in pristine_indices.items(): for name, sql in pristine_indices.items():
if name not in indices: if name not in indices:
self.log_execute("Creating new index %s" % name, sql) self.log_execute("Creating new index %s" % name, sql)
elif sql != indices[name]: elif sql != indices[name]:
self.log_execute( self.log_execute("Index %s changed: Dropping old version" % name, "DROP INDEX %s" % name)
"Index %s changed: Dropping old version" % name, self.log_execute("Index %s changed: Creating updated version in its place" % name, sql)
"DROP INDEX %s" % name)
self.log_execute(
"Index %s changed: Creating updated version in its place" %
name, sql)
self._migrate_pragma('user_version') self._migrate_pragma("user_version")
if self.pristine.execute("PRAGMA foreign_keys").fetchone()[0]: if self.pristine.execute("PRAGMA foreign_keys").fetchone()[0]:
if self.db.execute("PRAGMA foreign_key_check").fetchall(): if self.db.execute("PRAGMA foreign_key_check").fetchall():
raise RuntimeError( raise RuntimeError("Database migration: Would fail foreign_key_check")
"Database migration: Would fail foreign_key_check")
def _migrate_pragma(self, pragma): def _migrate_pragma(self, pragma):
pristine_val = self.pristine.execute( pristine_val = self.pristine.execute("PRAGMA %s" % pragma).fetchone()[0]
"PRAGMA %s" % pragma).fetchone()[0]
val = self.db.execute("PRAGMA %s" % pragma).fetchone()[0] val = self.db.execute("PRAGMA %s" % pragma).fetchone()[0]
if val != pristine_val: if val != pristine_val:
self.log_execute( self.log_execute(
"Set %s to %i from %i" % (pragma, pristine_val, val), "Set %s to %i from %i" % (pragma, pristine_val, val), "PRAGMA %s = %i" % (pragma, pristine_val)
"PRAGMA %s = %i" % (pragma, pristine_val)) )
return pristine_val return pristine_val
def _left_pad(text, indent=" "): def _left_pad(text, indent=" "):
"""Maybe I can find a package in pypi for this?""" """Maybe I can find a package in pypi for this?"""
return "\n".join(indent + line for line in text.split('\n')) return "\n".join(indent + line for line in text.split("\n"))
def normalise_sql(sql): def normalise_sql(sql):
# Remove comments: # Remove comments:
sql = re.sub(r'--[^\n]*\n', "", sql) sql = re.sub(r"--[^\n]*\n", "", sql)
# Normalise whitespace: # Normalise whitespace:
sql = re.sub(r'\s+', " ", sql) sql = re.sub(r"\s+", " ", sql)
sql = re.sub(r" *([(),]) *", r"\1", sql) sql = re.sub(r" *([(),]) *", r"\1", sql)
# Remove unnecessary quotes # Remove unnecessary quotes
sql = re.sub(r'"(\w+)"', r"\1", sql) sql = re.sub(r'"(\w+)"', r"\1", sql)
@ -258,8 +255,10 @@ def normalise_sql(sql):
def test_normalise_sql(): def test_normalise_sql():
assert normalise_sql("""\ assert (
normalise_sql("""\
CREATE TABLE "Node"( -- This is my table CREATE TABLE "Node"( -- This is my table
-- There are many like it but this one is mine -- There are many like it but this one is mine
A b, C D, "E F G", h)""") == \ A b, C D, "E F G", h)""")
'CREATE TABLE Node(A b,C D,"E F G",h)' == 'CREATE TABLE Node(A b,C D,"E F G",h)'
)

View file

@ -54,9 +54,9 @@ async def handle_triggers(ctx, message):
disabled = await ctx.bot.db.execute_fetchall( disabled = await ctx.bot.db.execute_fetchall(
"SELECT EXISTS(SELECT * FROM users WHERE user_id=:author AND disabled IS 1)", params "SELECT EXISTS(SELECT * FROM users WHERE user_id=:author AND disabled IS 1)", params
)[0] )
if disabled: if disabled[0][0]:
log.debug(f"User {ctx.author} ({ctx.author.id}) opted out") log.debug(f"User {ctx.author} ({ctx.author.id}) opted out")
return return
@ -152,12 +152,14 @@ class Notifications(Cog):
""" """
if test_keyword(keyword, regex): if test_keyword(keyword, regex):
log.debug("Keyword too common")
await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common") await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common")
return return
conflicts = await fetch_unpacked(ctx.bot.db, existing, params) conflicts = await fetch_unpacked(ctx.bot.db, existing, params)
if conflicts: if conflicts:
log.debug("Keyword conflicts with existing keyword")
await ctx.send(f"Any instance of `{keyword}` would be matched by existing keywords (check DMs)") await ctx.send(f"Any instance of `{keyword}` would be matched by existing keywords (check DMs)")
await ctx.author.send( await ctx.author.send(
f"Conflicts with keyword `{keyword}`:\n" + "\n".join(f"- `{conflict}`" for conflict in conflicts) f"Conflicts with keyword `{keyword}`:\n" + "\n".join(f"- `{conflict}`" for conflict in conflicts)
@ -167,12 +169,15 @@ class Notifications(Cog):
conflicts = await fetch_unpacked(ctx.bot.db, redundant, params) conflicts = await fetch_unpacked(ctx.bot.db, redundant, params)
if conflicts: if conflicts:
log.debug("Keyword renders existing redundant")
await ctx.send(f"Adding `{keyword}` will cause existing keywords to never match (check DMs)") await ctx.send(f"Adding `{keyword}` will cause existing keywords to never match (check DMs)")
await ctx.author.send( await ctx.author.send(
f"Keywords redundant from `{keyword}`:\n" + "\n".join(f" - `{conflict}`" for conflict in conflicts) f"Keywords redundant from `{keyword}`:\n" + "\n".join(f" - `{conflict}`" for conflict in conflicts)
) )
return return
log.debug("Keyword valid, adding")
await ctx.bot.db.execute( await ctx.bot.db.execute(
"INSERT INTO keywords (guild_id, keyword, user_id, regex) VALUES (:guild_id, :keyword, :user_id, :regex)", "INSERT INTO keywords (guild_id, keyword, user_id, regex) VALUES (:guild_id, :keyword, :user_id, :regex)",
params, params,