2
0
Prechádzať zdrojové kódy

Run in dedicated worker

Locutus 3 rokov pred
rodič
commit
a31c83e4c2
5 zmenil súbory, kde vykonal 199 pridanie a 82 odobranie
  1. 62 79
      arty/bot.py
  2. 43 3
      arty/main.py
  3. 19 0
      arty/messages.py
  4. 25 0
      arty/thread_queue.py
  5. 50 0
      arty/worker.py

+ 62 - 79
arty/bot.py

@@ -1,56 +1,64 @@
 import asyncio
 import random
-import time
-from io import BytesIO
-from threading import Thread
 
 import discord
 import torch
 
-from make_image import make_image
+from messages import SORRY_SLOW, NO_GPU
+from thread_queue import ThreadQueue
 
 
 class Bot:
-    SORRY = [
-        'my server sucks',
-        'my server is slow as fuck',
-        "i'm running on a toaster",
-        "i'm running on a flip-phone",
-        "i'm running on a tamagochi",
-        "i'm running on a potato",
-        "i'm not exactly running on a nasa computer",
-        'i lack the recources to do this fast',
-        "i'm sleepy",
-        "i'm lazy",
-    ]
-    NO_GPU = [
-        'my server lacks a graphics card',
-        "i'm running on a v-server with no gpu",
-        'my creator is too poor to get a server with gpu',
-        'i lack a gpu',
-        "i can't run on cuda (no gpu)"
-    ]
-
-    def __init__(self, secret: str):
-        self._client = None
+    """
+    Main bot class, handling messages to and from the bot.
+
+    Works the bot_queue to send queued messages.
+    on_message, will schedule drawing prompts on the worker_queue
+    """
+
+    def __init__(self, secret: str, bot_queue: ThreadQueue, worker_queue: ThreadQueue):
+        self._client: discord.Client = None
         self._secret = secret
+        self._worker_queue = worker_queue
+        self._bot_queue = bot_queue
 
+        self._busy = False
+        self._channels = {}
         self._setup()
 
-    def run(self):
-        # note: blocking!
-        self._client.run(self._secret)
+    async def run(self):
+        """
+        start the bot itself and start listening to the bot_queue
+        :return:
+        """
+        return await asyncio.gather(
+            self._start(),
+            self._work_queue(),
+        )
 
     async def on_ready(self):
         print(f'We have logged in as {self._client.user}')
 
-    async def on_message(self, message):
-        print(f'{message.author}: {message.clean_content} ({message.mentions})')
-        if message.author == self._client.user:
-            return
-
+    async def on_message(self, message: discord.Message):
         if message.clean_content.startswith(self._msg_prefix):
-            await self._draw_image(message)
+            self._channels[message.channel.id] = message.channel
+            await self._queue_prompt(message)
+
+    async def _work_queue(self):
+        while True:
+            message = await self._bot_queue.get()
+            busy = message.get('busy', None)
+            if busy is not None:
+                self._busy = busy
+                continue
+
+            channel = self._channels[message.pop('channel_id')]
+            text = message.pop('message')
+            await channel.send(text, **message)
+
+    async def _start(self):
+        async with self._client:
+            await self._client.start(self._secret)
 
     @property
     def _msg_prefix(self):
@@ -64,53 +72,28 @@ class Bot:
         self._client.event(self.on_ready)
         self._client.event(self.on_message)
 
-    async def _draw_image(self, message: discord.Message):
-        await self._ack_prompt(message)
-
-        prompt = message.clean_content[len(self._msg_prefix) + 1:]
-        print(f'drawing an image, prompt="{prompt}"')
-
-        image_grid = await self._make_image(prompt)
-
-        buffer = BytesIO()
-        image_grid['result'].convert('RGB')
-        image_grid['result'].save(buffer, format='JPEG')
-
-        buffer.seek(0)
-        picture = discord.File(buffer, filename='result.jpeg')
-
-        if image_grid['nsfw'][0]:
-            await message.channel.send(
-                "oh, this one's spicy (according to the nsfw filter, that i don't care about)",
-            )
-
-        await message.channel.send(f'{message.author.mention} {prompt}:', file=picture)
-
-    async def _ack_prompt(self, message: discord.Message):
-        messages = self.SORRY
+    async def _ack_reply(self, message: discord.Message):
+        # send a message to the channel to acknowledge the receipt of the drawing prompt
+        # add a random sorry message, because this shit is pretty fucking slow.
+        messages = SORRY_SLOW
         if not torch.cuda.is_available():
-            messages += self.NO_GPU
+            # honestly, if you end up here, you'll probably wait an hour for the drawing
+            messages += NO_GPU
 
         sorry = random.choice(messages)
-        await message.channel.send(f'No problem. Just give me some time, {sorry}.')
 
-    async def _make_image(self, prompt: str) -> dict:
-        def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
-            asyncio.set_event_loop(loop)
-            loop.run_forever()
+        queue_pos = ''
+        if self._busy:
+            # if something is in the queue, the worker is currently busy
+            queue_pos = f'waiting for {self._worker_queue.qsize() + 1} other picture(s) tho.'
 
-        async def task(_prompt) -> dict:
-            return make_image(_prompt)
+        await message.channel.send(f'No problem. Just give me some time, {sorry}. {queue_pos}')
 
-        loop = asyncio.new_event_loop()
-        t0 = time.time()
-
-        thread = Thread(target=start_background_loop, args=(loop,), daemon=True)
-        thread.start()
-
-        task = asyncio.run_coroutine_threadsafe(task(prompt), loop)
-        result = task.result()
-
-        loop.stop()
-        print(f'prompt rendered in {time.time() - t0}s')
-        return result
+    async def _queue_prompt(self, message: discord.Message):
+        prompt = message.clean_content[len(self._msg_prefix) + 1:]
+        await self._ack_reply(message)
+        await self._worker_queue.put({
+            'prompt': prompt,
+            'channel_id': message.channel.id,
+            'sender_mention': message.author.mention,
+        })

+ 43 - 3
arty/main.py

@@ -1,5 +1,10 @@
+import asyncio
 import os.path
+from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
 
+from arty.thread_queue import ThreadQueue
+from worker import Worker
 from make_image import make_image
 from utils import get_token
 from bot import Bot
@@ -19,8 +24,43 @@ def _init():
     print('first time setup done.')
 
 
+def run_in_thread(worker, target_loop):
+    def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+        asyncio.set_event_loop(loop)
+        loop.run_forever()
+
+    async def task():
+        await worker.run()
+
+    thread = Thread(target=start_background_loop, args=(target_loop,), daemon=True)
+    thread.start()
+
+    task = asyncio.run_coroutine_threadsafe(task(), target_loop)
+    task.result()
+
+
+async def _main():
+    worker_queue = ThreadQueue()
+    bot_queue = ThreadQueue()
+
+    loop1 = asyncio.new_event_loop()
+    loop2 = asyncio.new_event_loop()
+
+    worker = Worker(bot_queue, worker_queue)
+    bot = Bot(get_token('discord'), bot_queue, worker_queue)
+
+    executor = ThreadPoolExecutor(max_workers=2)
+
+    loop = asyncio.get_event_loop()
+    await asyncio.wait(
+        fs={
+            loop.run_in_executor(executor, run_in_thread, worker, loop1),
+            loop.run_in_executor(executor, run_in_thread, bot, loop2),
+        },
+        return_when=asyncio.ALL_COMPLETED
+    )
+
+
 if __name__ == '__main__':
     _init()
-
-    bot = Bot(get_token('discord'))
-    bot.run()
+    asyncio.run(_main())

+ 19 - 0
arty/messages.py

@@ -0,0 +1,19 @@
+SORRY_SLOW = [
+    'my server sucks',
+    'my server is slow as fuck',
+    "i'm running on a toaster",
+    "i'm running on a flip-phone",
+    "i'm running on a tamagochi",
+    "i'm running on a potato",
+    "i'm not exactly running on a nasa computer",
+    'i lack the recources to do this fast',
+    "i'm sleepy",
+    "i'm lazy",
+]
+NO_GPU = [
+    'my server lacks a graphics card',
+    "i'm running on a v-server with no gpu",
+    'my creator is too poor to get a server with gpu',
+    'i lack a gpu',
+    "i can't run on cuda (no gpu)"
+]

+ 25 - 0
arty/thread_queue.py

@@ -0,0 +1,25 @@
+import asyncio
+
+
+class ThreadQueue(asyncio.Queue):
+    def __init__(self):
+        super().__init__()
+        self._timeout = 0.02
+
+    async def get(self):
+        while True:
+            try:
+                return self.get_nowait()
+            except asyncio.QueueEmpty:
+                await asyncio.sleep(self._timeout)
+            except Exception as E:
+                raise
+
+    async def put(self, data):
+        while True:
+            try:
+                return self.put_nowait(data)
+            except asyncio.QueueFull:
+                await asyncio.sleep(self._timeout)
+            except Exception as E:
+                raise

+ 50 - 0
arty/worker.py

@@ -0,0 +1,50 @@
+import asyncio
+from io import BytesIO
+
+import discord
+from make_image import make_image
+
+
+class Worker:
+
+    def __init__(self, bot_queue: asyncio.Queue, worker_queue: asyncio.Queue):
+        self._bot_queue = bot_queue
+        self._worker_queue = worker_queue
+
+    async def run(self):
+        while True:
+            message = await self._worker_queue.get()
+            await self._bot_queue.put({'busy': True})
+            await self._solve(message)
+
+            await asyncio.sleep(1)
+
+    async def _solve(self, task_info: dict):
+        image_grid = make_image(task_info['prompt'])
+
+        buffer = BytesIO()
+        image_grid['result'].convert('RGB')
+        image_grid['result'].save(buffer, format='JPEG')
+
+        buffer.seek(0)
+        picture = discord.File(buffer, filename='result.jpeg')
+
+        if image_grid['nsfw'][0]:
+            await self._send(
+                task_info['channel_id'],
+                "oh, this one's spicy (according to the nsfw filter, that i don't care about)",
+            )
+
+        await self._send(
+            task_info['channel_id'],
+            f'{task_info["sender_mention"]} {task_info["prompt"]}:',
+            file=picture,
+        )
+        await self._bot_queue.put({'busy': False})
+
+    async def _send(self, channel_id: int, message: str, **kwargs):
+        await self._bot_queue.put({
+            'channel_id': channel_id,
+            'message': message,
+            **kwargs
+        })