|
|
@@ -1,56 +1,64 @@
|
|
|
import asyncio
|
|
|
import random
|
|
|
-import time
|
|
|
-from io import BytesIO
|
|
|
-from threading import Thread
|
|
|
|
|
|
import discord
|
|
|
import torch
|
|
|
|
|
|
-from make_image import make_image
|
|
|
+from messages import SORRY_SLOW, NO_GPU
|
|
|
+from thread_queue import ThreadQueue
|
|
|
|
|
|
|
|
|
class Bot:
|
|
|
- SORRY = [
|
|
|
- 'my server sucks',
|
|
|
- 'my server is slow as fuck',
|
|
|
- "i'm running on a toaster",
|
|
|
- "i'm running on a flip-phone",
|
|
|
- "i'm running on a tamagochi",
|
|
|
- "i'm running on a potato",
|
|
|
- "i'm not exactly running on a nasa computer",
|
|
|
- 'i lack the recources to do this fast',
|
|
|
- "i'm sleepy",
|
|
|
- "i'm lazy",
|
|
|
- ]
|
|
|
- NO_GPU = [
|
|
|
- 'my server lacks a graphics card',
|
|
|
- "i'm running on a v-server with no gpu",
|
|
|
- 'my creator is too poor to get a server with gpu',
|
|
|
- 'i lack a gpu',
|
|
|
- "i can't run on cuda (no gpu)"
|
|
|
- ]
|
|
|
-
|
|
|
- def __init__(self, secret: str):
|
|
|
- self._client = None
|
|
|
+ """
|
|
|
+ Main bot class, handling messages to and from the bot.
|
|
|
+
|
|
|
+ Works the bot_queue to send queued messages.
|
|
|
+ on_message, will schedule drawing prompts on the worker_queue
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, secret: str, bot_queue: ThreadQueue, worker_queue: ThreadQueue):
|
|
|
+ self._client: discord.Client = None
|
|
|
self._secret = secret
|
|
|
+ self._worker_queue = worker_queue
|
|
|
+ self._bot_queue = bot_queue
|
|
|
|
|
|
+ self._busy = False
|
|
|
+ self._channels = {}
|
|
|
self._setup()
|
|
|
|
|
|
- def run(self):
|
|
|
- # note: blocking!
|
|
|
- self._client.run(self._secret)
|
|
|
+ async def run(self):
|
|
|
+ """
|
|
|
+ start the bot itself and start listening to the bot_queue
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return await asyncio.gather(
|
|
|
+ self._start(),
|
|
|
+ self._work_queue(),
|
|
|
+ )
|
|
|
|
|
|
async def on_ready(self):
|
|
|
print(f'We have logged in as {self._client.user}')
|
|
|
|
|
|
- async def on_message(self, message):
|
|
|
- print(f'{message.author}: {message.clean_content} ({message.mentions})')
|
|
|
- if message.author == self._client.user:
|
|
|
- return
|
|
|
-
|
|
|
+ async def on_message(self, message: discord.Message):
|
|
|
if message.clean_content.startswith(self._msg_prefix):
|
|
|
- await self._draw_image(message)
|
|
|
+ self._channels[message.channel.id] = message.channel
|
|
|
+ await self._queue_prompt(message)
|
|
|
+
|
|
|
+ async def _work_queue(self):
|
|
|
+ while True:
|
|
|
+ message = await self._bot_queue.get()
|
|
|
+ busy = message.get('busy', None)
|
|
|
+ if busy is not None:
|
|
|
+ self._busy = busy
|
|
|
+ continue
|
|
|
+
|
|
|
+ channel = self._channels[message.pop('channel_id')]
|
|
|
+ text = message.pop('message')
|
|
|
+ await channel.send(text, **message)
|
|
|
+
|
|
|
+ async def _start(self):
|
|
|
+ async with self._client:
|
|
|
+ await self._client.start(self._secret)
|
|
|
|
|
|
@property
|
|
|
def _msg_prefix(self):
|
|
|
@@ -64,53 +72,28 @@ class Bot:
|
|
|
self._client.event(self.on_ready)
|
|
|
self._client.event(self.on_message)
|
|
|
|
|
|
- async def _draw_image(self, message: discord.Message):
|
|
|
- await self._ack_prompt(message)
|
|
|
-
|
|
|
- prompt = message.clean_content[len(self._msg_prefix) + 1:]
|
|
|
- print(f'drawing an image, prompt="{prompt}"')
|
|
|
-
|
|
|
- image_grid = await self._make_image(prompt)
|
|
|
-
|
|
|
- buffer = BytesIO()
|
|
|
- image_grid['result'].convert('RGB')
|
|
|
- image_grid['result'].save(buffer, format='JPEG')
|
|
|
-
|
|
|
- buffer.seek(0)
|
|
|
- picture = discord.File(buffer, filename='result.jpeg')
|
|
|
-
|
|
|
- if image_grid['nsfw'][0]:
|
|
|
- await message.channel.send(
|
|
|
- "oh, this one's spicy (according to the nsfw filter, that i don't care about)",
|
|
|
- )
|
|
|
-
|
|
|
- await message.channel.send(f'{message.author.mention} {prompt}:', file=picture)
|
|
|
-
|
|
|
- async def _ack_prompt(self, message: discord.Message):
|
|
|
- messages = self.SORRY
|
|
|
+ async def _ack_reply(self, message: discord.Message):
|
|
|
+ # send a message to the channel to acknowledge the receipt of the drawing prompt
|
|
|
+ # add a random sorry message, because this shit is pretty fucking slow.
|
|
|
+ messages = SORRY_SLOW
|
|
|
if not torch.cuda.is_available():
|
|
|
- messages += self.NO_GPU
|
|
|
+ # honestly, if you end up here, you'll probably wait an hour for the drawing
|
|
|
+ messages += NO_GPU
|
|
|
|
|
|
sorry = random.choice(messages)
|
|
|
- await message.channel.send(f'No problem. Just give me some time, {sorry}.')
|
|
|
|
|
|
- async def _make_image(self, prompt: str) -> dict:
|
|
|
- def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
|
|
- asyncio.set_event_loop(loop)
|
|
|
- loop.run_forever()
|
|
|
+ queue_pos = ''
|
|
|
+ if self._busy:
|
|
|
+ # if something is in the queue, the worker is currently busy
|
|
|
+ queue_pos = f'waiting for {self._worker_queue.qsize() + 1} other picture(s) tho.'
|
|
|
|
|
|
- async def task(_prompt) -> dict:
|
|
|
- return make_image(_prompt)
|
|
|
+ await message.channel.send(f'No problem. Just give me some time, {sorry}. {queue_pos}')
|
|
|
|
|
|
- loop = asyncio.new_event_loop()
|
|
|
- t0 = time.time()
|
|
|
-
|
|
|
- thread = Thread(target=start_background_loop, args=(loop,), daemon=True)
|
|
|
- thread.start()
|
|
|
-
|
|
|
- task = asyncio.run_coroutine_threadsafe(task(prompt), loop)
|
|
|
- result = task.result()
|
|
|
-
|
|
|
- loop.stop()
|
|
|
- print(f'prompt rendered in {time.time() - t0}s')
|
|
|
- return result
|
|
|
+ async def _queue_prompt(self, message: discord.Message):
|
|
|
+ prompt = message.clean_content[len(self._msg_prefix) + 1:]
|
|
|
+ await self._ack_reply(message)
|
|
|
+ await self._worker_queue.put({
|
|
|
+ 'prompt': prompt,
|
|
|
+ 'channel_id': message.channel.id,
|
|
|
+ 'sender_mention': message.author.mention,
|
|
|
+ })
|