convert db helpers to methods

This commit is contained in:
Infinidoge 2024-12-31 14:14:46 -05:00
parent 3a4e537828
commit e31d19bcca
Signed by: Infinidoge
SSH key fingerprint: SHA256:GT2StvPQMMfFHyiiFJymQxfTG/z6EWLJ6NWItf5K5sA
3 changed files with 32 additions and 18 deletions

View file

@ -70,6 +70,32 @@ def run_db_migrations(db_file):
log.debug("Finished running automatic migration") log.debug("Finished running automatic migration")
async def fetch_singleton(db, sql, params=None):
"""
Fetch an object from the database, with the assumption that the result is 1 row by 1 column
"""
result = await db.execute_fetchall(f"{sql} LIMIT 1", params)
return result[0][0]
async def fetch_exists(db, sql, params=None):
return await fetch_singleton(db, f"SELECT EXISTS({sql})", params)
async def fetch_unpacked(db, sql, params=None):
cur = await db.cursor()
cur.row_factory = lambda cursor, row: row[0]
cur = await cur.execute(sql, params)
return await cur.fetchall()
log.debug("Monkeypatching in helpers")
aiosqlite.Connection.fetch_singleton = fetch_singleton
aiosqlite.Connection.fetch_exists = fetch_exists
aiosqlite.Connection.fetch_unpacked = fetch_unpacked
async def setup_db(db_file): async def setup_db(db_file):
log.debug(f"Connecting to {db_file}") log.debug(f"Connecting to {db_file}")
db = await aiosqlite.connect(db_file) db = await aiosqlite.connect(db_file)

View file

@ -6,7 +6,7 @@ from typing import Union
from disnake import Embed, Member, TextChannel from disnake import Embed, Member, TextChannel
from disnake.ext.commands import Cog, group, guild_only from disnake.ext.commands import Cog, group, guild_only
from .utils import can_view, confirm, fetch_exists, fetch_unpacked, test_keyword from .utils import can_view, confirm, test_keyword
log = logging.getLogger("nomen.notifications") log = logging.getLogger("nomen.notifications")
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
@ -73,8 +73,8 @@ async def handle_triggers(ctx, message):
"is_bot": ctx.author.bot, "is_bot": ctx.author.bot,
} }
disabled = await fetch_exists( disabled = await ctx.bot.db.fetch_exists(
ctx.bot.db, "SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params "SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params
) )
if disabled: if disabled:
@ -181,7 +181,7 @@ class Notifications(Cog):
await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common") await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common")
return return
conflicts = await fetch_unpacked(ctx.bot.db, existing, params) conflicts = await ctx.bot.db.fetch_unpacked(existing, params)
if conflicts: if conflicts:
log.debug("Keyword conflicts with existing keyword") log.debug("Keyword conflicts with existing keyword")
@ -191,7 +191,7 @@ class Notifications(Cog):
) )
return return
conflicts = await fetch_unpacked(ctx.bot.db, redundant, params) conflicts = await ctx.bot.db.fetch_unpacked(redundant, params)
if conflicts: if conflicts:
log.debug("Keyword renders existing redundant") log.debug("Keyword renders existing redundant")
@ -304,7 +304,7 @@ class Notifications(Cog):
@guild_only() @guild_only()
async def pause(self, ctx): async def pause(self, ctx):
params = (ctx.author.id, ctx.guild_id) params = (ctx.author.id, ctx.guild_id)
if await fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?"): if await ctx.bot.db.fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?", params):
await ctx.bot.db.execute("DELETE FROM user_pauses WHERE user_id=? AND guild_id=?", params) await ctx.bot.db.execute("DELETE FROM user_pauses WHERE user_id=? AND guild_id=?", params)
await ctx.bot.send(f"Resumed notifications in {ctx.guild}") await ctx.bot.send(f"Resumed notifications in {ctx.guild}")
else: else:

View file

@ -111,18 +111,6 @@ def unpack(lst_of_tpl):
return list(map(first, lst_of_tpl)) return list(map(first, lst_of_tpl))
async def fetch_unpacked(db, sql, params=None):
cur = await db.cursor()
cur.row_factory = lambda cursor, row: first(row)
cur = await cur.execute(sql, params)
return await cur.fetchall()
async def fetch_exists(db, sql, params=None):
result = await db.execute_fetchall(f"SELECT EXISTS({sql})", params)
return result[0][0]
async def in_thread(member, thread): async def in_thread(member, thread):
# FIXME: Currently overlooks the situation where a moderator isn't in a thread but has manage threads # FIXME: Currently overlooks the situation where a moderator isn't in a thread but has manage threads
return any(member.id == thread_member.id for thread_member in await thread.fetch_members()) return any(member.id == thread_member.id for thread_member in await thread.fetch_members())