
Mostly performance improvements

Locutus, 3 years ago
parent
commit 0ee2dea2ad
6 files changed, 89 additions and 47 deletions
  1. arty/bot.py  + 22 - 5
  2. arty/messages.py  + 19 - 1
  3. arty/renderer.py  + 40 - 36
  4. arty/worker.py  + 5 - 3
  5. main.py  + 3 - 2
  6. sample.py  + 0 - 0

+ 22 - 5
arty/bot.py

@@ -4,7 +4,7 @@ import random
 import discord
 import torch

-from arty.messages import SORRY_SLOW, NO_GPU
+from arty.messages import SORRY_SLOW, NO_GPU, LIVENESS, LIVENESS_RESPONSE
 from arty.thread_queue import ThreadQueue


@@ -42,7 +42,10 @@ class Bot:
     async def on_message(self, message: discord.Message):
         if message.clean_content.startswith(self._msg_prefix):
             self._channels[message.channel.id] = message.channel
-            await self._queue_prompt(message)
+            return await self._queue_prompt(message)
+
+        if self._is_liveness_check(message.clean_content):
+            return await self._answer_liveness_check(message)

     async def _work_queue(self):
         while True:
@@ -52,9 +55,12 @@ class Bot:
                 self._busy = busy
                 continue

-            channel = self._channels[message.pop('channel_id')]
-            text = message.pop('message')
-            await channel.send(text, **message)
+            try:
+                channel = self._channels[message.pop('channel_id')]
+                text = message.pop('message')
+                await channel.send(text, **message)
+            except Exception as e:
+                print(e)

     async def _start(self):
         async with self._client:
@@ -97,3 +103,14 @@ class Bot:
             'channel_id': message.channel.id,
             'sender_mention': message.author.mention,
         })
+
+    def _is_liveness_check(self, content: str):
+        for message in LIVENESS:
+            if content.startswith(f'@{self._client.user.name} {message}'):
+                return True
+        return False
+
+    @staticmethod
+    async def _answer_liveness_check(message: discord.Message):
+        response = random.choice(LIVENESS_RESPONSE)
+        await message.channel.send(response)

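For context, a minimal standalone sketch of what the new liveness check does, using the LIVENESS and LIVENESS_RESPONSE lists added to arty/messages.py below; the bot name 'Arty' is a placeholder, the real check uses self._client.user.name:

import random

from arty.messages import LIVENESS, LIVENESS_RESPONSE

def is_liveness_check(content: str, bot_name: str) -> bool:
    # Mirrors Bot._is_liveness_check: the message has to mention the bot by
    # name followed by one of the known phrases, e.g. "@Arty u there?".
    return any(content.startswith(f'@{bot_name} {phrase}') for phrase in LIVENESS)

if is_liveness_check('@Arty u there?', bot_name='Arty'):
    print(random.choice(LIVENESS_RESPONSE))  # picks one canned reply at random
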
+ 19 - 1
arty/messages.py

@@ -6,10 +6,11 @@ SORRY_SLOW = [
     "i'm running on a tamagochi",
     "i'm running on a potato",
     "i'm not exactly running on a nasa computer",
-    'i lack the recources to do this fast',
+    'i lack the resources to do this fast',
     "i'm sleepy",
     "i'm lazy",
 ]
+
 NO_GPU = [
     'my server lacks a graphics card',
     "i'm running on a v-server with no gpu",
@@ -17,3 +18,20 @@ NO_GPU = [
     'i lack a gpu',
     "i can't run on cuda (no gpu)"
 ]
+
+LIVENESS = [
+    'u there?',
+    'hello?',
+    'are you there?',
+    'how are you?',
+]
+
+LIVENESS_RESPONSE = [
+    "i'm ready to take your order",
+    "i'm ready to take requests",
+    'always happy to help ya :)',
+    'here!',
+    'how may i assist?',
+    "i'm there!",
+    "i'm there and ready to draw something!",
+]

+ 40 - 36
arty/make_image.py → arty/renderer.py

@@ -27,7 +27,7 @@ class Unsafe(StableDiffusionPipeline):
         self._patched = True

     def _to_grid(self, images):
-        rows = 2
+        rows = math.floor(math.sqrt(len(images)))
         cols = math.ceil(len(images) / rows)

         w, h = images[0].size
@@ -40,7 +40,7 @@ class Unsafe(StableDiffusionPipeline):

     def __call__(self, *args, **kwargs):
         num_images = kwargs.pop('num_images', 4)
-        print(f'rendering {num_images} images')
+        print('rendering...')

         self._patch_checker()
         has_nsfw = False
@@ -49,43 +49,47 @@ class Unsafe(StableDiffusionPipeline):
         for _i in range(num_images):
             result = super(Unsafe, self).__call__(*args, **kwargs)
             samples.append(result['sample'][0])
-
-            has_nsfw = has_nsfw or result['nsfw_content_detected']
+            has_nsfw = has_nsfw or result['nsfw_content_detected'][0]

         return {'result': self._to_grid(samples), 'nsfw': has_nsfw}


-def make_image(prompt: str, num_images=4):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+class Renderer:
+    def __init__(self):
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'

-    args_by_device = {
-        'cuda': {
-            'torch_dtype': torch.float16,
-            'revision': 'fp16',
-        },
-        'cpu': {
-            'torch_dtype': torch.float32,
-            'revision': 'main',
+        args_by_device = {
+            'cuda': {
+                'torch_dtype': torch.float16,
+                'revision': 'fp16',
+            },
+            'cpu': {
+                'torch_dtype': torch.float32,
+                'revision': 'main',
+            }
         }
-    }
-    pretrained_args = args_by_device[device]
-
-    lms = LMSDiscreteScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear"
-    )
-
-    local_cache = os.path.join(os.path.dirname(__file__), '..', 'model_data')
-    cache_dir = '/model_data' if os.path.isdir('/model_data') else local_cache
-    pipe = Unsafe.from_pretrained(
-        "CompVis/stable-diffusion-v1-4",
-        use_auth_token=get_token('huggingface'),
-        cache_dir=cache_dir,
-        scheduler=lms,
-        **pretrained_args,
-    ).to(device)
-
-    # autocasting to cuda even when running on CPU. no idea why, but that works
-    with autocast('cuda'):
-        return pipe(prompt, num_images=num_images)
+        pretrained_args = args_by_device[device]
+
+        lms = LMSDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear"
+        )
+
+        local_cache = os.path.join(os.path.dirname(__file__), '..', 'model_data')
+        cache_dir = '/model_data' if os.path.isdir('/model_data') else local_cache
+        self.pipe = Unsafe.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            use_auth_token=get_token('huggingface'),
+            cache_dir=cache_dir,
+            scheduler=lms,
+            **pretrained_args,
+        ).to(device)
+
+    def make_image(self, prompt: str, num_images=4):
+        try:
+            # autocasting to cuda even when running on CPU. no idea why, but that works
+            with autocast('cuda'):
+                return self.pipe(prompt, num_images=num_images)
+        except Exception as e:
+            print(e)

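Aside from the class refactor, the grid layout changed: instead of a hardcoded two rows, _to_grid now sizes the grid from the image count. A small sketch of that arithmetic (standalone, nothing beyond the math module):

import math

def grid_shape(num_images: int) -> tuple[int, int]:
    # Same calculation as the updated _to_grid: rows from the square root,
    # columns whatever is needed to hold the remaining images.
    rows = math.floor(math.sqrt(num_images))
    cols = math.ceil(num_images / rows)
    return rows, cols

# 1 -> (1, 1), 4 -> (2, 2), 6 -> (2, 3), 9 -> (3, 3); the old rows = 2 would
# have put a single image into a 2x1 grid with an empty cell.
for n in (1, 4, 6, 9):
    print(n, grid_shape(n))
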
+ 5 - 3
arty/worker.py

@@ -3,7 +3,7 @@ from io import BytesIO

 import discord

-from arty.make_image import make_image
+from arty.renderer import Renderer


 class Worker:
@@ -11,6 +11,7 @@ class Worker:
     def __init__(self, bot_queue: asyncio.Queue, worker_queue: asyncio.Queue):
         self._bot_queue = bot_queue
         self._worker_queue = worker_queue
+        self._renderer = Renderer()

     async def run(self):
         while True:
@@ -21,16 +22,17 @@ class Worker:
             await asyncio.sleep(1)

     async def _solve(self, task_info: dict):
-        image_grid = make_image(task_info['prompt'])
+        image_grid = self._renderer.make_image(task_info['prompt'])

         buffer = BytesIO()
         image_grid['result'].convert('RGB')
         image_grid['result'].save(buffer, format='JPEG')

         buffer.seek(0)
+
         picture = discord.File(buffer, filename='result.jpeg')

-        if image_grid['nsfw'][0]:
+        if image_grid['nsfw']:
             await self._send(
                 task_info['channel_id'],
                 "oh, this one's spicy (according to the nsfw filter, that i don't care about)",

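The worker now builds one Renderer in __init__ and reuses it for every task, which is presumably where most of the commit's performance win comes from: the Stable Diffusion pipeline is loaded once per process instead of on every request. A rough usage sketch (the prompt strings are just examples; make_image returns None when rendering raised, per the try/except above):

from arty.renderer import Renderer

renderer = Renderer()  # slow: loads the pipeline and moves it to the device

# Later calls reuse the loaded pipeline, so only inference cost is paid per prompt.
for prompt in ('something cute', 'something cuter'):
    grid = renderer.make_image(prompt, num_images=1)
    if grid is not None:  # rendering failed and was only printed, skip this one
        grid['result'].convert('RGB').save('result.jpeg')
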
+ 3 - 2
main.py

@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor

 from arty.thread_queue import ThreadQueue
 from arty.worker import Worker
-from arty.make_image import make_image
+from arty.renderer import Renderer
 from arty.utils import get_token, run_in_thread
 from arty.bot import Bot

@@ -15,7 +15,8 @@ def _init():
         return

     print('running first time setup. this will take some time.')
-    make_image('something cute', num_images=1)
+    renderer = Renderer()
+    renderer.make_image('something cute', num_images=1)

     with open(lock, 'w') as f:
         f.write('done')

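main.py keeps the same first-run pattern, just routed through the new class: warm the model up once (which also pulls the weights into the cache) and drop a lock file so later starts skip the step. A rough sketch of that flow, assuming a lock path named 'first_run.lock' (the real path is set outside this hunk):

import os

from arty.renderer import Renderer

def init(lock: str = 'first_run.lock'):
    if os.path.isfile(lock):
        return  # setup already ran on an earlier start
    print('running first time setup. this will take some time.')
    Renderer().make_image('something cute', num_images=1)
    with open(lock, 'w') as f:
        f.write('done')
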
+ 0 - 0
sample.py