bot.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import asyncio
  2. import random
  3. import time
  4. from io import BytesIO
  5. from threading import Thread
  6. import discord
  7. import torch
  8. from make_image import make_image
  9. class Bot:
  10. SORRY = [
  11. 'my server sucks',
  12. 'my server is slow as fuck',
  13. "i'm running on a toaster",
  14. "i'm running on a flip-phone",
  15. "i'm running on a tamagochi",
  16. "i'm running on a potato",
  17. "i'm not exactly running on a nasa computer",
  18. 'i lack the recources to do this fast',
  19. "i'm sleepy",
  20. "i'm lazy",
  21. ]
  22. NO_GPU = [
  23. 'my server lacks a graphics card',
  24. "i'm running on a v-server with no gpu",
  25. 'my creator is too poor to get a server with gpu',
  26. 'i lack a gpu',
  27. "i can't run on cuda (no gpu)"
  28. ]
  29. def __init__(self, secret: str):
  30. self._client = None
  31. self._secret = secret
  32. self._setup()
  33. def run(self):
  34. # note: blocking!
  35. self._client.run(self._secret)
  36. async def on_ready(self):
  37. print(f'We have logged in as {self._client.user}')
  38. async def on_message(self, message):
  39. print(f'{message.author}: {message.clean_content} ({message.mentions})')
  40. if message.author == self._client.user:
  41. return
  42. if message.clean_content.startswith(self._msg_prefix):
  43. await self._draw_image(message)
  44. @property
  45. def _msg_prefix(self):
  46. return f'@{self._client.user.name} draw'
  47. def _setup(self):
  48. intents = discord.Intents.default()
  49. intents.message_content = True
  50. self._client = discord.Client(intents=intents)
  51. self._client.event(self.on_ready)
  52. self._client.event(self.on_message)
  53. async def _draw_image(self, message: discord.Message):
  54. await self._ack_prompt(message)
  55. prompt = message.clean_content[len(self._msg_prefix) + 1:]
  56. print(f'drawing an image, prompt="{prompt}"')
  57. image_grid = await self._make_image(prompt)
  58. buffer = BytesIO()
  59. image_grid['result'].convert('RGB')
  60. image_grid['result'].save(buffer, format='JPEG')
  61. buffer.seek(0)
  62. picture = discord.File(buffer, filename='result.jpeg')
  63. if image_grid['nsfw'][0]:
  64. await message.channel.send(
  65. "oh, this one's spicy (according to the nsfw filter, that i don't care about)",
  66. )
  67. await message.channel.send(f'{message.author.mention} {prompt}:', file=picture)
  68. async def _ack_prompt(self, message: discord.Message):
  69. messages = self.SORRY
  70. if not torch.cuda.is_available():
  71. messages += self.NO_GPU
  72. sorry = random.choice(messages)
  73. await message.channel.send(f'No problem. Just give me some time, {sorry}.')
  74. async def _make_image(self, prompt: str) -> dict:
  75. def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
  76. asyncio.set_event_loop(loop)
  77. loop.run_forever()
  78. async def task(_prompt) -> dict:
  79. return make_image(_prompt)
  80. loop = asyncio.new_event_loop()
  81. t0 = time.time()
  82. thread = Thread(target=start_background_loop, args=(loop,), daemon=True)
  83. thread.start()
  84. task = asyncio.run_coroutine_threadsafe(task(prompt), loop)
  85. result = task.result()
  86. loop.stop()
  87. print(f'prompt rendered in {time.time() - t0}s')
  88. return result