convert db helpers to methods
This commit is contained in:
parent
3a4e537828
commit
e31d19bcca
3 changed files with 32 additions and 18 deletions
26
nomen/db.py
26
nomen/db.py
|
@ -70,6 +70,32 @@ def run_db_migrations(db_file):
|
||||||
log.debug("Finished running automatic migration")
|
log.debug("Finished running automatic migration")
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_singleton(db, sql, params=None):
|
||||||
|
"""
|
||||||
|
Fetch an object from the database, with the assumption that the result is 1 row by 1 column
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await db.execute_fetchall(f"{sql} LIMIT 1", params)
|
||||||
|
return result[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_exists(db, sql, params=None):
|
||||||
|
return await fetch_singleton(db, f"SELECT EXISTS({sql})", params)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_unpacked(db, sql, params=None):
|
||||||
|
cur = await db.cursor()
|
||||||
|
cur.row_factory = lambda cursor, row: row[0]
|
||||||
|
cur = await cur.execute(sql, params)
|
||||||
|
return await cur.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
log.debug("Monkeypatching in helpers")
|
||||||
|
aiosqlite.Connection.fetch_singleton = fetch_singleton
|
||||||
|
aiosqlite.Connection.fetch_exists = fetch_exists
|
||||||
|
aiosqlite.Connection.fetch_unpacked = fetch_unpacked
|
||||||
|
|
||||||
|
|
||||||
async def setup_db(db_file):
|
async def setup_db(db_file):
|
||||||
log.debug(f"Connecting to {db_file}")
|
log.debug(f"Connecting to {db_file}")
|
||||||
db = await aiosqlite.connect(db_file)
|
db = await aiosqlite.connect(db_file)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Union
|
||||||
from disnake import Embed, Member, TextChannel
|
from disnake import Embed, Member, TextChannel
|
||||||
from disnake.ext.commands import Cog, group, guild_only
|
from disnake.ext.commands import Cog, group, guild_only
|
||||||
|
|
||||||
from .utils import can_view, confirm, fetch_exists, fetch_unpacked, test_keyword
|
from .utils import can_view, confirm, test_keyword
|
||||||
|
|
||||||
log = logging.getLogger("nomen.notifications")
|
log = logging.getLogger("nomen.notifications")
|
||||||
log.setLevel(logging.DEBUG)
|
log.setLevel(logging.DEBUG)
|
||||||
|
@ -73,8 +73,8 @@ async def handle_triggers(ctx, message):
|
||||||
"is_bot": ctx.author.bot,
|
"is_bot": ctx.author.bot,
|
||||||
}
|
}
|
||||||
|
|
||||||
disabled = await fetch_exists(
|
disabled = await ctx.bot.db.fetch_exists(
|
||||||
ctx.bot.db, "SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params
|
"SELECT * FROM users WHERE user_id=:author AND unlikely(disabled IS 1)", params
|
||||||
)
|
)
|
||||||
|
|
||||||
if disabled:
|
if disabled:
|
||||||
|
@ -181,7 +181,7 @@ class Notifications(Cog):
|
||||||
await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common")
|
await ctx.send(f"{'Regex' if regex else 'Keyword'} matches a word that is too common")
|
||||||
return
|
return
|
||||||
|
|
||||||
conflicts = await fetch_unpacked(ctx.bot.db, existing, params)
|
conflicts = await ctx.bot.db.fetch_unpacked(existing, params)
|
||||||
|
|
||||||
if conflicts:
|
if conflicts:
|
||||||
log.debug("Keyword conflicts with existing keyword")
|
log.debug("Keyword conflicts with existing keyword")
|
||||||
|
@ -191,7 +191,7 @@ class Notifications(Cog):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
conflicts = await fetch_unpacked(ctx.bot.db, redundant, params)
|
conflicts = await ctx.bot.db.fetch_unpacked(redundant, params)
|
||||||
|
|
||||||
if conflicts:
|
if conflicts:
|
||||||
log.debug("Keyword renders existing redundant")
|
log.debug("Keyword renders existing redundant")
|
||||||
|
@ -304,7 +304,7 @@ class Notifications(Cog):
|
||||||
@guild_only()
|
@guild_only()
|
||||||
async def pause(self, ctx):
|
async def pause(self, ctx):
|
||||||
params = (ctx.author.id, ctx.guild_id)
|
params = (ctx.author.id, ctx.guild_id)
|
||||||
if await fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?"):
|
if await ctx.bot.db.fetch_exists("SELECT * FROM user_pauses WHERE user_id=? AND guild_id=?", params):
|
||||||
await ctx.bot.db.execute("DELETE FROM user_pauses WHERE user_id=? AND guild_id=?", params)
|
await ctx.bot.db.execute("DELETE FROM user_pauses WHERE user_id=? AND guild_id=?", params)
|
||||||
await ctx.bot.send(f"Resumed notifications in {ctx.guild}")
|
await ctx.bot.send(f"Resumed notifications in {ctx.guild}")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -111,18 +111,6 @@ def unpack(lst_of_tpl):
|
||||||
return list(map(first, lst_of_tpl))
|
return list(map(first, lst_of_tpl))
|
||||||
|
|
||||||
|
|
||||||
async def fetch_unpacked(db, sql, params=None):
|
|
||||||
cur = await db.cursor()
|
|
||||||
cur.row_factory = lambda cursor, row: first(row)
|
|
||||||
cur = await cur.execute(sql, params)
|
|
||||||
return await cur.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_exists(db, sql, params=None):
|
|
||||||
result = await db.execute_fetchall(f"SELECT EXISTS({sql})", params)
|
|
||||||
return result[0][0]
|
|
||||||
|
|
||||||
|
|
||||||
async def in_thread(member, thread):
|
async def in_thread(member, thread):
|
||||||
# FIXME: Currently overlooks the situation where a moderator isn't in a thread but has manage threads
|
# FIXME: Currently overlooks the situation where a moderator isn't in a thread but has manage threads
|
||||||
return any(member.id == thread_member.id for thread_member in await thread.fetch_members())
|
return any(member.id == thread_member.id for thread_member in await thread.fetch_members())
|
||||||
|
|
Loading…
Reference in a new issue