from __future__ import annotations
import concurrent.futures
import importlib
import logging
import typing as t
import hikari
from dawn.commands.slash.base import Option, SlashCallable
from dawn.commands.slash.command import SlashCommand
from dawn.commands.slash.groups import SlashGroup
from dawn.context.slash import SlashContext
from dawn.errors import CommandAlreadyExists, ModuleAlreadyLoaded
from dawn.extensions import Extension
from dawn.internals import CommandManager
# Public names exposed by a star-import of this module.
__all__: t.Tuple[str, ...] = ("Bot",)
# Module-level logger, registered under the "dawn.bot" name.
_LOGGER = logging.getLogger("dawn.bot")
class Bot(hikari.GatewayBot, CommandManager):
    """The handler's Bot class.

    This is a subclass of :class:`~hikari.GatewayBot` with all the features of
    the parent class supported.

    Parameters
    ----------
    token: :class:`str`
        The bot token for your application.
    default_guilds: :class:`Optional[Sequence[int]]`
        Default guilds in which the commands will be added.
    purge_extra: :class:`bool`
        Commands not bound with this class get deleted if set to `True`.

    Notes
    -----
    The remaining keyword-only arguments (``allow_color``, ``banner``,
    ``executor``, ``force_color``, ``cache_settings``, ``intents``,
    ``auto_chunk_members``, ``logs``, ``max_rate_limit``, ``max_retries``,
    ``proxy_settings`` and ``rest_url``) are forwarded unchanged to
    :class:`~hikari.GatewayBot`.
    """

    # Every attribute assigned in ``__init__`` is listed here, including the
    # ``_LOGGER`` reference, so the declaration matches what is actually set.
    __slots__: t.Tuple[str, ...] = (
        "_slash_commands",
        "_purge_extra",
        "_slash_groups",
        "_loaded_modules",
        "_extensions",
        "_LOGGER",
        "default_guilds",
    )
def __init__(
    self,
    token: str,
    *,
    default_guilds: t.Sequence[int] | None = None,
    purge_extra: bool = False,
    allow_color: bool = True,
    banner: str | None = "dawn",
    executor: concurrent.futures.Executor | None = None,
    force_color: bool = True,
    cache_settings: hikari.impl.CacheSettings | None = None,
    intents: hikari.Intents = hikari.Intents.ALL_UNPRIVILEGED,
    auto_chunk_members: bool = True,
    logs: None | int | str | t.Dict[str, t.Any] = "INFO",
    max_rate_limit: float = 300,
    max_retries: int = 3,
    proxy_settings: hikari.impl.ProxySettings | None = None,
    rest_url: str | None = None,
) -> None:
    """Create the command registries, initialise the gateway bot and subscribe listeners.

    ``default_guilds`` and ``purge_extra`` are stored on the instance; all
    other keyword arguments are passed straight to the parent initialiser.
    """
    # Handler state: set up before the parent initialiser is invoked.
    self._LOGGER = _LOGGER
    self.default_guilds = default_guilds
    self._purge_extra = purge_extra
    self._loaded_modules: t.List[str] = []
    self._extensions: t.Dict[str, Extension] = {}
    self._slash_groups: t.Dict[str, SlashGroup] = {}
    self._slash_commands: t.Dict[str, SlashCommand] = {}

    super().__init__(
        token,
        allow_color=allow_color,
        banner=banner,
        executor=executor,
        force_color=force_color,
        cache_settings=cache_settings,
        intents=intents,
        auto_chunk_members=auto_chunk_members,
        logs=logs,
        max_rate_limit=max_rate_limit,
        max_retries=max_retries,
        proxy_settings=proxy_settings,
        rest_url=rest_url,
    )

    # Subscription order is kept: slash dispatch, start-up hook, autocomplete.
    listeners = (
        (hikari.InteractionCreateEvent, self.process_slash_commands),
        (hikari.StartedEvent, self._update_commands),
        (hikari.InteractionCreateEvent, self.process_autocomplete),
    )
    for event_type, callback in listeners:
        self.event_manager.subscribe(event_type, callback)
@property
def slash_commands(self) -> t.Mapping[str, SlashCommand]:
    """Registered slash commands, keyed by command name.

    Returns
    -------
    :class:`Mapping[str, SlashCommand]`
        The bot's own name-to-command registry (not a copy).
    """
    return self._slash_commands
@property
def slash_groups(self) -> t.Mapping[str, SlashGroup]:
    """Registered slash command groups, keyed by group name.

    Returns
    -------
    :class:`Mapping[str, SlashGroup]`
        The bot's own name-to-group registry (not a copy).
    """
    return self._slash_groups
@property
def extensions(self) -> t.Mapping[str, Extension]:
    """Mapping of extension names to :class:`Extension` objects.

    Returns
    -------
    :class:`Mapping[str, Extension]`
        The bot's own name-to-extension registry (not a copy).
    """
    return self._extensions
def add_slash_command(self, command: SlashCommand) -> None:
    """Add a slash command to the bot's bucket.

    Parameters
    ----------
    command: :class:`SlashCommand`
        The slash command to be added to bucket.

    Raises
    ------
    :class:`CommandAlreadyExists`
        A slash command or a slash group is already registered under the
        same name.
    """
    name = command.name
    # Commands and groups are both looked up by the interaction's command
    # name, so a clash with either registry must be rejected. Membership is
    # tested on the key rather than on the truthiness of the stored object.
    if name in self._slash_commands or name in self._slash_groups:
        raise CommandAlreadyExists(name)
    self._slash_commands[name] = command
def add_slash_group(self, group: SlashGroup) -> None:
    """Add a slash command group to bot's bucket.

    Parameters
    ----------
    group: :class:`SlashGroup`
        The group to be added.

    Raises
    ------
    :class:`CommandAlreadyExists`
        A slash group or a slash command is already registered under the
        same name.
    """
    name = group.name
    # Test key membership instead of the truthiness of the stored object, so
    # the duplicate check does not depend on how groups/commands evaluate as
    # booleans.
    if name in self._slash_groups or name in self._slash_commands:
        raise CommandAlreadyExists(name)
    self._slash_groups[name] = group
def get_slash_context(self, event: hikari.InteractionCreateEvent) -> SlashContext:
    """Build a :class:`SlashContext` for an :class:`~hikari.InteractionCreateEvent`.

    Parameters
    ----------
    event: :class:`~hikari.InteractionCreateEvent`
        The event to wrap; it is passed to the context together with the bot.
    """
    context = SlashContext(self, event)
    return context
def get_slash_command(self, name: str, /) -> SlashCommand | None:
    """Look up a registered :class:`.SlashCommand` by name.

    Parameters
    ----------
    name: :class:`str`
        Name of the command.

    Returns
    -------
    :class:`Optional[SlashCommand]`
        The matching command, or ``None`` when no command has that name.
    """
    registry = self._slash_commands
    return registry[name] if name in registry else None
def get_extension(self, name: str, /) -> Extension | None:
    """Look up a loaded extension object by its name.

    Parameters
    ----------
    name: :class:`str`
        Name of the extension.

    Returns
    -------
    :class:`Optional[Extension]`
        The matching extension, or ``None`` when none is stored under that name.
    """
    try:
        return self._extensions[name]
    except KeyError:
        return None
def slash(self, command: SlashCommand) -> SlashCommand:
    """Decorator that registers a slash command on the bot.

    The command is added through :meth:`add_slash_command` and then returned
    unchanged, so the decorated name still refers to the command object.

    Example
    -------
    >>> import dawn
    >>>
    >>> bot = dawn.Bot("TOKEN")
    >>>
    >>> @bot.slash
    ... @dawn.slash_command("ping")
    ... async def ping(context: dawn.SlashContext) -> None:
    ...     await context.create_response("pong!")
    >>>
    >>> bot.run()
    """
    self.add_slash_command(command)
    return command
def add_extension(self, extension: Extension, /) -> None:
    """Add an extension to the bot.

    The extension's ``create_setup`` method is called with this bot.

    Parameters
    ----------
    extension: :class:`.Extension`
        The extension to be loaded.

    Raises
    ------
    :class:`ValueError`
        The extension provided wasn't derived from :class:`.Extension` class.
    """
    if isinstance(extension, Extension):
        extension.create_setup(self)
        return
    raise ValueError(f"Expected a `dawn.Extension`, got {type(extension)}")
def load_module(self, module_path: str, /) -> None:
    """Import a module and call its ``load`` function with this bot.

    Parameters
    ----------
    module_path: :class:`str`
        Import path of the module (as accepted by
        :func:`importlib.import_module`).

    Raises
    ------
    :class:`ModuleAlreadyLoaded`
        ``module_path`` has already been loaded through this method.
    :class:`AttributeError`
        The module has no usable ``load`` attribute.
    """
    if module_path in self._loaded_modules:
        raise ModuleAlreadyLoaded(module_path)

    module = importlib.import_module(module_path)
    # A default of ``None`` lets a missing hook reach the explicit error
    # below instead of escaping as a bare attribute lookup failure.
    load_function = getattr(module, "load", None)
    if not load_function:
        # AttributeError is an ``Exception`` subclass, so handlers written
        # for either earlier outcome (missing or falsy hook) keep working.
        raise AttributeError(f"No load function found in module {module_path!r}.")

    load_function(self)
    # Only record the module once its hook has run without raising.
    self._loaded_modules.append(module_path)
async def process_slash_commands(
    self, event: hikari.InteractionCreateEvent, /
) -> None:
    """Route command interactions to the matching slash command or group.

    Events whose interaction is not a :class:`~hikari.CommandInteraction`
    are ignored. The command registry is consulted before the group registry.

    Parameters
    ----------
    event: :class:`~hikari.InteractionCreateEvent`
        The event related to this call.
    """
    inter = event.interaction
    if not isinstance(inter, hikari.CommandInteraction):
        return

    target = self._slash_commands.get(inter.command_name) or self._slash_groups.get(
        inter.command_name
    )
    if target:
        await self.invoke_slash_command(event, target)
async def invoke_slash_command(
    self, event: hikari.InteractionCreateEvent, command: SlashCommand | SlashGroup
) -> None:
    """Execute a :class:`.SlashCommand`, or the sub-command selected inside a group.

    Parameters
    ----------
    event: :class:`hikari.InteractionCreateEvent`
        The event related to this call. Nothing happens unless its
        interaction is a :class:`~hikari.CommandInteraction`.
    command: :class:`.SlashCommand` | :class:`.SlashGroup`
        The slash command to run, or the group whose sub-command (possibly
        nested in a sub-group) is selected by the interaction's first option.

    Raises
    ------
    :class:`LookupError`
        The interaction selects a direct sub-command that the group does not
        contain.
    """
    inter = event.interaction
    if not isinstance(inter, hikari.CommandInteraction):
        return

    if isinstance(command, SlashCommand):
        kwargs = await self._prepare_kwargs(inter, inter.options or [])
        await command(self.get_slash_context(event), **kwargs)
        return
    if not isinstance(command, SlashGroup):
        return

    top_level = inter.options or []
    if not top_level:
        # A group interaction without any option names no sub-command.
        return
    sub = top_level[0]

    if sub.type == hikari.OptionType.SUB_COMMAND:
        to_call = command._subcommands.get(sub.name)
        if not to_call:
            raise LookupError(
                f"No subcommand named {sub.name!r} in slash group {inter.command_name!r}."
            )
        call_options = sub.options or []
    elif sub.type == hikari.OptionType.SUB_COMMAND_GROUP:
        # Resolve through the group that was passed in, exactly like the
        # direct sub-command branch does.
        subgroup = command._subgroups.get(sub.name)
        nested = sub.options or []
        if not subgroup or not nested:
            return
        # The first nested option is the sub-command chosen in the sub-group.
        to_call = subgroup._subcommands.get(nested[0].name)
        if not to_call:
            return
        call_options = nested[0].options or []
    else:
        return

    kwargs = await self._prepare_kwargs(inter, call_options)
    await to_call(self.get_slash_context(event), **kwargs)
async def process_autocomplete(self, event: hikari.InteractionCreateEvent) -> None:
    """Route autocomplete interactions to the matching slash command or group.

    Events whose interaction is not a
    :class:`~hikari.AutocompleteInteraction` are ignored. The command
    registry is consulted before the group registry.

    Parameters
    ----------
    event: :class:`~hikari.InteractionCreateEvent`
        The event related to this call.
    """
    inter = event.interaction
    if not isinstance(inter, hikari.AutocompleteInteraction):
        return

    target = self._slash_commands.get(inter.command_name) or self._slash_groups.get(
        inter.command_name
    )
    if target:
        await self.trigger_autcomplete_for(target, inter)
async def trigger_autcomplete_for(
    self, command: SlashCommand | SlashGroup, inter: hikari.AutocompleteInteraction
) -> None:
    """Run the autocomplete callback registered for the focused option.

    The command owning the focused option is ``command`` itself, a direct
    sub-command of the group, or a sub-command nested in one of the group's
    sub-groups. If no owner, focused option or callback is found, nothing is
    sent.

    Parameters
    ----------
    command: :class:`.SlashCommand` | :class:`.SlashGroup`
        The command or group the interaction was issued for.
    inter: :class:`~hikari.AutocompleteInteraction`
        The autocomplete interaction to answer.
    """
    top_level = inter.options or []

    if isinstance(command, SlashCommand):
        owner = command
        candidates = top_level
    elif isinstance(command, SlashGroup) and top_level:
        sub = top_level[0]
        if sub.type == hikari.OptionType.SUB_COMMAND:
            owner = command._subcommands.get(sub.name)
            candidates = sub.options or []
        elif sub.type == hikari.OptionType.SUB_COMMAND_GROUP:
            subgroup = command._subgroups.get(sub.name)
            nested = sub.options or []
            if not subgroup or not nested:
                return
            # The first nested option is the sub-command itself; the option
            # being typed lives one level deeper, among its own options.
            owner = subgroup._subcommands.get(nested[0].name)
            candidates = nested[0].options or []
        else:
            return
    else:
        return

    if not owner:
        return

    target_option = next((opt for opt in candidates if opt.is_focused), None)
    if target_option is None:
        return

    autocomplete = owner.autocompletes.get(target_option.name)
    if autocomplete:
        responses = await autocomplete(inter, target_option)
        await self._send_triggered_autocompletes(inter, responses)
async def _send_triggered_autocompletes(
    self, inter: hikari.AutocompleteInteraction, responses: list[t.Any]
) -> None:
    """Send autocomplete suggestions as the response to ``inter``.

    Items that are already :class:`~hikari.CommandChoice` objects are sent as
    they are; anything else becomes a choice whose name and value are both
    ``str(item)``.
    """
    choices: t.List[hikari.CommandChoice] = []
    for item in responses:
        if not isinstance(item, hikari.CommandChoice):
            item = hikari.CommandChoice(name=str(item), value=str(item))
        choices.append(item)
    await inter.create_response(choices)
async def _prepare_kwargs(
    self,
    inter: hikari.CommandInteraction,
    options: t.Sequence[hikari.CommandInteractionOption],
) -> t.Mapping[str, t.Any]:
    """Convert interaction options into keyword arguments for a command callback.

    Each option is stored under its name. Conversions by option type:

    * ``CHANNEL`` - the cached guild channel for the id.
    * ``USER`` - the cached member, fetched over REST on a cache miss;
      ``None`` when the interaction has no guild.
    * ``ROLE`` - the cached role; ``None`` when the interaction has no guild.
    * ``ATTACHMENT`` - the matching entry of the interaction's resolved
      attachments.
    * anything else - the raw option value.

    Raises
    ------
    :class:`RuntimeError`
        An attachment option was received but the interaction carries no
        resolved data.
    """
    kwargs: t.Dict[str, t.Any] = {}
    for opt in options or []:
        if opt.type == hikari.OptionType.CHANNEL and isinstance(opt.value, int):
            kwargs[opt.name] = self.cache.get_guild_channel(opt.value)
        elif opt.type == hikari.OptionType.USER and isinstance(opt.value, int):
            if (g_id := inter.guild_id) is None:
                kwargs[opt.name] = None
            else:
                # Prefer the cache; only hit the REST API on a miss.
                kwargs[opt.name] = self.cache.get_member(
                    g_id, opt.value
                ) or await self.rest.fetch_member(g_id, opt.value)
        elif opt.type == hikari.OptionType.ROLE and isinstance(opt.value, int):
            if not inter.guild_id:
                kwargs[opt.name] = None
            else:
                kwargs[opt.name] = self.cache.get_role(opt.value)
        elif opt.type == hikari.OptionType.ATTACHMENT and isinstance(
            opt.value, hikari.Snowflake
        ):
            if (res := inter.resolved) is None:
                # Still an ``Exception`` subclass, now with a usable message.
                raise RuntimeError(
                    f"Attachment option {opt.name!r} was received without resolved data."
                )
            kwargs[opt.name] = res.attachments.get(opt.value)
        else:
            kwargs[opt.name] = opt.value
    return kwargs
def _has_guild_binded(self, command: SlashCommand | SlashGroup) -> bool:
    """Report whether any guild list applies to ``command``.

    Three sources are considered: the command's own ``guild_ids``, the
    ``default_guilds`` of its extension (when it has one) and the bot's
    ``default_guilds``. The result is ``True`` if at least one is non-empty.
    """
    own_guilds = command.guild_ids
    extension = command.extension
    extension_guilds = extension.default_guilds if extension else []
    bot_guilds = self.default_guilds
    return bool(own_guilds or extension_guilds or bot_guilds)
async def _update_commands(self, event: hikari.StartedEvent) -> None:
    """Run the four command-sync handlers in order when the bot has started.

    The handlers run one after another: global commands, global groups,
    guild commands, guild groups. ``event`` itself is not used.
    """
    # NOTE(review): each handler receives the bot as an explicit argument even
    # though it is looked up on ``self`` - confirm against their definitions.
    handlers = (
        self._handle_global_command,
        self._handle_global_groups,
        self._handle_guild_commands,
        self._handle_guild_groups,
    )
    for handler in handlers:
        await handler(self)