Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process text-to-image requested image count sequentially #66

Merged
merged 2 commits into from May 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 26 additions & 27 deletions runner/app/routes/text_to_image.py
Expand Up @@ -7,7 +7,7 @@
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import logging
import random
import os
import os, json

router = APIRouter()

Expand Down Expand Up @@ -57,33 +57,32 @@ async def text_to_image(
),
)

if params.seed is None:
params.seed = random.randint(0, 2**32 - 1)
if params.num_images_per_prompt > 1:
params.seed = [
i for i in range(params.seed, params.seed + params.num_images_per_prompt)
]
seed = params.seed if params.seed is not None else random.randint(0, 2**32 - 1)
seeds = [seed + i for i in range(params.num_images_per_prompt)]

try:
images, has_nsfw_concept = pipeline(**params.model_dump())
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("TextToImagePipeline error")
)

seeds = params.seed
if not isinstance(seeds, list):
seeds = [seeds]
# TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again
# once LIV-243 and LIV-379 are resolved.
images = []
has_nsfw_concept = []
params.num_images_per_prompt = 1
for seed in seeds:
try:
params.seed = [seed]
imgs, nsfw_check = pipeline(**params.model_dump())
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("TextToImagePipeline error")
)

output_images = []
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
output_images = [
{"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
]

return {"images": output_images}