|
@@ -27,7 +27,7 @@ class Unsafe(StableDiffusionPipeline):
|
|
|
self._patched = True
|
|
self._patched = True
|
|
|
|
|
|
|
|
def _to_grid(self, images):
|
|
def _to_grid(self, images):
|
|
|
- rows = 2
|
|
|
|
|
|
|
+ rows = math.floor(math.sqrt(len(images)))
|
|
|
cols = math.ceil(len(images) / rows)
|
|
cols = math.ceil(len(images) / rows)
|
|
|
|
|
|
|
|
w, h = images[0].size
|
|
w, h = images[0].size
|
|
@@ -40,7 +40,7 @@ class Unsafe(StableDiffusionPipeline):
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
    """Render `num_images` samples for the same arguments and return them
    composited into one grid image.

    Extra keyword (popped before delegating to the parent pipeline):
        num_images: how many samples to render (default 4).

    Returns:
        dict with:
            'result': single grid image built by ``self._to_grid``
            'nsfw':   True if the safety checker flagged any sample
    """
    num_images = kwargs.pop('num_images', 4)
    print('rendering...')

    # Neutralize the built-in safety checker before sampling.
    self._patch_checker()

    has_nsfw = False
    # NOTE(review): this initializer was lost in the garbled source and is
    # reconstructed here — the loop below appends to `samples`, so it must
    # start as an empty list.
    samples = []
    for _i in range(num_images):
        # Each parent-pipeline call renders one image; batching is done by looping.
        result = super(Unsafe, self).__call__(*args, **kwargs)
        samples.append(result['sample'][0])
        # nsfw_content_detected is per-image; index [0] picks the flag for the
        # single image rendered this iteration.
        has_nsfw = has_nsfw or result['nsfw_content_detected'][0]

    return {'result': self._to_grid(samples), 'nsfw': has_nsfw}
|
|
|
|
|
|
|
|
|
|
|
|
|
class Renderer:
    """Holds a configured Stable Diffusion pipeline (the `Unsafe` subclass)
    and renders image grids from text prompts.

    The pipeline is built once in ``__init__`` so repeated ``make_image``
    calls reuse the loaded weights instead of reloading them per request.
    """

    def __init__(self):
        # Pick fp16 weights on GPU, fp32 on CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        args_by_device = {
            'cuda': {
                'torch_dtype': torch.float16,
                'revision': 'fp16',
            },
            'cpu': {
                'torch_dtype': torch.float32,
                'revision': 'main',
            }
        }
        pretrained_args = args_by_device[device]

        lms = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear"
        )

        # Prefer a container-style /model_data mount; otherwise fall back to a
        # model_data directory next to this package.
        local_cache = os.path.join(os.path.dirname(__file__), '..', 'model_data')
        cache_dir = '/model_data' if os.path.isdir('/model_data') else local_cache
        self.pipe = Unsafe.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            use_auth_token=get_token('huggingface'),
            cache_dir=cache_dir,
            scheduler=lms,
            **pretrained_args,
        ).to(device)

    def make_image(self, prompt: str, num_images=4):
        """Render `num_images` samples for `prompt` via the wrapped pipeline.

        Returns the pipeline's result dict ({'result': grid, 'nsfw': bool}),
        or None if rendering raised (the error is printed, not re-raised).
        """
        try:
            # autocasting to cuda even when running on CPU. no idea why, but that works
            with autocast('cuda'):
                return self.pipe(prompt, num_images=num_images)
        except Exception as e:
            # Best-effort: failures are printed and None is returned.
            # NOTE(review): consider logger.exception(...) and/or re-raising so
            # callers can tell a failed render from a successful one.
            print(e)