Add support for i2vgen-xl im2vid #56

Draft · wants to merge 4 commits into base: main
36 changes: 31 additions & 5 deletions runner/app/pipelines/image_to_video.py
@@ -1,7 +1,7 @@
 from app.pipelines.base import Pipeline
 from app.pipelines.util import get_torch_device, get_model_dir
 
-from diffusers import StableVideoDiffusionPipeline
+from diffusers import StableVideoDiffusionPipeline, I2VGenXLPipeline
 from huggingface_hub import file_download
 import torch
 import PIL
@@ -15,6 +15,8 @@
 
 logger = logging.getLogger(__name__)
 
+I2VGEN_LIGHTNING_MODEL_ID = "ali-vilab/i2vgen-xl"
+SVD_LIGHTNING_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
 
 class ImageToVideoPipeline(Pipeline):
     def __init__(self, model_id: str):
@@ -37,7 +39,11 @@ def __init__(self, model_id: str):
             kwargs["variant"] = "fp16"
 
         self.model_id = model_id
-        self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
+
+        if I2VGEN_LIGHTNING_MODEL_ID in model_id:
+            self.ldm = I2VGenXLPipeline.from_pretrained(model_id, **kwargs)
+        else:
+            self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
         self.ldm.to(get_torch_device())
 
         if os.environ.get("SFAST"):
@@ -50,9 +56,6 @@ def __init__(self, model_id: str):
             self.ldm = compile_model(self.ldm)
 
     def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
-        if "decode_chunk_size" not in kwargs:
-            kwargs["decode_chunk_size"] = 4
-
         seed = kwargs.pop("seed", None)
         if seed is not None:
             if isinstance(seed, int):
@@ -64,6 +67,29 @@ def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]:
                     torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                 ]
 
+        if SVD_LIGHTNING_MODEL_ID in self.model_id:
+            if "decode_chunk_size" not in kwargs:
+                kwargs["decode_chunk_size"] = 4
+            if "prompt" in kwargs:
+                del kwargs["prompt"]
+        elif I2VGEN_LIGHTNING_MODEL_ID in self.model_id:
+            kwargs["num_inference_steps"] = 75
+            if "decode_chunk_size" not in kwargs:
+                kwargs["decode_chunk_size"] = 8
+            if "num_frames" not in kwargs:
+                kwargs["num_frames"] = 25
+            if "fps" in kwargs:
+                del kwargs["fps"]
+            if "motion_bucket_id" in kwargs:
+                del kwargs["motion_bucket_id"]
+            if "noise_aug_strength" in kwargs:
+                del kwargs["noise_aug_strength"]
+            prompt = ""
+            if "prompt" in kwargs:
+                prompt = kwargs["prompt"]
+                del kwargs["prompt"]
+            return self.ldm(prompt, image, **kwargs).frames
+
         return self.ldm(image, **kwargs).frames
 
     def __str__(self) -> str:
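For reference, the i2vgen-xl branch above maps onto standard diffusers usage roughly as follows. This is a hedged, standalone sketch rather than the runner code: the input image path, prompt, and seed are placeholders, and it assumes a CUDA device plus a diffusers release that ships I2VGenXLPipeline.

# Standalone sketch of the i2vgen-xl path: 75 steps, 25 frames, chunked VAE
# decoding, mirroring the defaults the pipeline branch sets above.
import torch
from diffusers import I2VGenXLPipeline
from diffusers.utils import export_to_gif, load_image

pipe = I2VGenXLPipeline.from_pretrained(
    "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

image = load_image("input.png").convert("RGB")  # placeholder input frame
generator = torch.Generator("cuda").manual_seed(42)  # placeholder seed

frames = pipe(
    prompt="a placeholder text prompt",
    image=image,
    num_inference_steps=75,
    num_frames=25,
    decode_chunk_size=8,
    generator=generator,
).frames[0]

export_to_gif(frames, "i2vgen_xl_sample.gif")

A larger decode_chunk_size decodes more latent frames per VAE pass, which is faster but uses more memory; the SVD branch keeps its smaller default of 4 for that reason.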
2 changes: 2 additions & 0 deletions runner/app/routes/image_to_video.py
@@ -32,6 +32,7 @@
 )
 async def image_to_video(
     image: Annotated[UploadFile, File()],
+    prompt: Annotated[str, Form()] = "",
     model_id: Annotated[str, Form()] = "",
     height: Annotated[int, Form()] = 576,
     width: Annotated[int, Form()] = 1024,
@@ -74,6 +75,7 @@ async def image_to_video(
         batch_frames = pipeline(
             image=Image.open(image.file).convert("RGB"),
             height=height,
+            prompt=prompt,
             width=width,
             fps=fps,
             motion_bucket_id=motion_bucket_id,
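On the API side, the new prompt form field rides along with the existing multipart parameters. A hypothetical client call might look like the following; the runner address and the /image-to-video route path are assumptions here, not taken from this diff.

# Hypothetical request exercising the new `prompt` field; URL and path assumed.
import requests

with open("input.png", "rb") as f:  # placeholder input image
    resp = requests.post(
        "http://localhost:8000/image-to-video",  # assumed runner address and route
        files={"image": ("input.png", f, "image/png")},
        data={
            "model_id": "ali-vilab/i2vgen-xl",
            "prompt": "a placeholder text prompt",  # used by i2vgen-xl, dropped for SVD
        },
    )
resp.raise_for_status()
print(resp.json())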
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
@@ -80,6 +80,9 @@ else
     # Download image-to-video models.
     huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
 
+    # Download image-to-video models (i2vgen-xl).
+    huggingface-cli download ali-vilab/i2vgen-xl --include "*.fp16.safetensors" "*.json" --cache-dir models
+
     # Download image-to-video models (token-gated).
     printf "\nDownloading token-gated models...\n"
     check_hf_auth
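If the download ever needs to be scripted from Python instead of the CLI, a roughly equivalent call via huggingface_hub would be the following, reusing the same repo id, include patterns, and cache directory as the line above.

# Rough Python equivalent of the huggingface-cli download line above.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ali-vilab/i2vgen-xl",
    allow_patterns=["*.fp16.safetensors", "*.json"],
    cache_dir="models",
)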
5 changes: 5 additions & 0 deletions runner/openapi.json
@@ -289,6 +289,11 @@
                     "title": "Model Id",
                     "default": ""
                 },
+                "prompt": {
+                    "type": "string",
+                    "title": "Prompt",
+                    "default": ""
+                },
                 "height": {
                     "type": "integer",
                     "title": "Height",
39 changes: 20 additions & 19 deletions worker/runner.gen.go

Some generated files are not rendered by default.