From e31d19bcca9292da805dde6d226d30a39b8e1449 Mon Sep 17 00:00:00 2001 From: Infinidoge Date: Tue, 31 Dec 2024 14:14:46 -0500 Subject: [PATCH] convert db helpers to methods --- nomen/db.py | 26 ++++++++++++++++++++++++++ nomen/notifications.py | 12 ++++++------ nomen/utils.py | 12 ------------ 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/nomen/db.py b/nomen/db.py index 07f594c..7a38522 100644 --- a/nomen/db.py +++ b/nomen/db.py @@ -70,6 +70,32 @@ def run_db_migrations(db_file): log.debug("Finished running automatic migration") +async def fetch_singleton(db, sql, params=None): + """ + Fetch an object from the database, with the assumption that the result is 1 row by 1 column + """ + + result = await db.execute_fetchall(f"{sql} LIMIT 1", params) + return result[0][0] + + +async def fetch_exists(db, sql, params=None): + return await fetch_singleton(db, f"SELECT EXISTS({sql})", params) + + +async def fetch_unpacked(db, sql, params=None): + cur = await db.cursor() + cur.row_factory = lambda cursor, row: row[0] + cur = await cur.execute(sql, params) + return await cur.fetchall() + + +log.debug("Monkeypatching in helpers") +aiosqlite.Connection.fetch_singleton = fetch_singleton +aiosqlite.Connection.fetch_exists = fetch_exists +aiosqlite.Connection.fetch_unpacked = fetch_unpacked + + async def setup_db(db_file): log.debug(f"Connecting to {db_file}") db = await aiosqlite.connect(db_file) diff --git a/nomen/notifications.py b/nomen/notifications.py index 48e5de1..704964d 100644 --- a/nomen/notifications.py +++ b/nomen/notifications.py @@ -6,7 +6,7 @@ from typing import Union from disnake import Embed, Member, TextChannel from disnake.ext.commands import Cog, group, guild_only -from .utils import can_view, confirm, fetch_exists, fetch_unpacked, test_keyword +from .utils import can_view, confirm, test_keyword log = logging.getLogger("nomen.notifications") log.setLevel(logging.DEBUG) @@ -73,8 +73,8 @@ async def handle_triggers(ctx, message): "is_bot": ctx.author.bot, } - disabled = await fetch_exists( - ctx.bot.db, "SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params + disabled = await ctx.bot.db.fetch_exists( + "SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params ) if disabled: @@ -181,7 +181,7 @@ class Notifications(Cog): await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common") return - conflicts = await fetch_unpacked(ctx.bot.db, existing, params) + conflicts = await ctx.bot.db.fetch_unpacked(existing, params) if conflicts: log.debug("Keyword conflicts with existing keyword") @@ -191,7 +191,7 @@ class Notifications(Cog): ) return - conflicts = await fetch_unpacked(ctx.bot.db, redundant, params) + conflicts = await ctx.bot.db.fetch_unpacked(redundant, params) if conflicts: log.debug("Keyword renders existing redundant") @@ -304,7 +304,7 @@ class Notifications(Cog): @guild_only() async def pause(self, ctx): params = (ctx.author.id, ctx.guild_id) - if await fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?"): + if await ctx.bot.db.fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?", params): await ctx.bot.db.execute("DELETE FROM user_pauses WHERE user_id=? AND guild_id=?", params) await ctx.bot.send(f"Resumed notifications in {ctx.guild}") else: diff --git a/nomen/utils.py b/nomen/utils.py index 3b9861e..417a9c8 100644 --- a/nomen/utils.py +++ b/nomen/utils.py @@ -111,18 +111,6 @@ def unpack(lst_of_tpl): return list(map(first, lst_of_tpl)) -async def fetch_unpacked(db, sql, params=None): - cur = await db.cursor() - cur.row_factory = lambda cursor, row: first(row) - cur = await cur.execute(sql, params) - return await cur.fetchall() - - -async def fetch_exists(db, sql, params=None): - result = await db.execute_fetchall(f"SELECT EXISTS({sql})", params) - return result[0][0] - - async def in_thread(member, thread): # FIXME: Currently overlooks the situation where a moderator isn't in a thread but has manage threads return any(member.id == thread_member.id for thread_member in await thread.fetch_members())