Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment with combined base+refiner for SDXL #55

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions runner/app/pipelines/text_to_image.py
Expand Up @@ -6,6 +6,7 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
StableDiffusionXLImg2ImgPipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
Expand All @@ -18,6 +19,7 @@
logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"


class TextToImagePipeline(Pipeline):
Expand Down Expand Up @@ -90,6 +92,20 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif SDXL_BASE_MODEL_ID in self.model_id:
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"
kwargs["use_safetensors"] = True
self.ldm = StableDiffusionXLPipeline.from_pretrained(model_id, **kwargs).to("cuda")
self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.ldm.text_encoder_2,
vae=self.ldm.vae,
torch_dtype=kwargs["torch_dtype"],
use_safetensors=True,
variant=kwargs["variant"],
).to("cuda")

else:
self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
torch_device
Expand Down Expand Up @@ -156,6 +172,16 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
else:
# Default to 2step
kwargs["num_inference_steps"] = 2
elif SDXL_BASE_MODEL_ID in self.model_id:
kwargs["num_inference_steps"] = 40
kwargs["denoising_end"] = 0.8
kwargs["output_type"] = "latent"
image = self.ldm(prompt, **kwargs).images
del kwargs["output_type"]
del kwargs["denoising_end"]
kwargs["image"] = image
kwargs["denoising_start"] = 0.8
return self.refiner(prompt, **kwargs).images

return self.ldm(prompt, **kwargs).images

Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Expand Up @@ -80,6 +80,9 @@ else
# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download the SDXL refiner model (used with the SDXL base pipeline).
huggingface-cli download stabilityai/stable-diffusion-xl-refiner-1.0 --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand Down