Skip to content

Commit

Permalink
multi: Multi-O for image-to-video
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 19, 2024
1 parent 2c7b6bb commit a436dc3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
10 changes: 10 additions & 0 deletions cmd/livepeer/starter/starter.go
Expand Up @@ -567,6 +567,16 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}

constraints[core.Capability_ImageToImage].Models[config.ModelID] = modelConstraint
case "image-to-video":
_, ok := constraints[core.Capability_ImageToVideo]
if !ok {
aiCaps = append(aiCaps, core.Capability_ImageToVideo)
constraints[core.Capability_ImageToVideo] = &core.Constraints{
Models: make(map[string]*core.ModelConstraint),
}
}

constraints[core.Capability_ImageToVideo].Models[config.ModelID] = modelConstraint
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Expand Up @@ -69,6 +69,7 @@ const (
Capability_SegmentSlicing
Capability_TextToImage
Capability_ImageToImage
Capability_ImageToVideo
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -102,6 +103,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_SegmentSlicing: "Segment slicing",
Capability_TextToImage: "Text to image",
Capability_ImageToImage: "Image to image",
Capability_ImageToVideo: "Image to video",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -189,6 +191,7 @@ func OptionalCapabilities() []Capability {
Capability_H264_Decode_420_10bit,
Capability_TextToImage,
Capability_ImageToImage,
Capability_ImageToVideo,
}
}

Expand Down
45 changes: 27 additions & 18 deletions server/ai_process.go
Expand Up @@ -10,9 +10,7 @@ import (
"path/filepath"
"sort"
"strings"
"time"

"github.com/cenkalti/backoff"
"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/clog"
"github.com/livepeer/go-livepeer/common"
Expand All @@ -21,11 +19,10 @@ import (
"github.com/livepeer/go-tools/drivers"
)

const imageToVideoTimeout = 5 * time.Minute
const imageToVideoRetryBackoff = 1 * time.Minute
const maxProcessingRetries = 4
const defaultTextToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"

type ServiceUnavailableError struct {
err error
Expand Down Expand Up @@ -255,33 +252,45 @@ func submitImageToImage(ctx context.Context, url string, req worker.ImageToImage
}

func processImageToVideo(ctx context.Context, params aiRequestParams, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// Discover 1 orchestrator
// TODO: Discover multiple orchestrators
caps := core.NewCapabilities(core.DefaultCapabilities(), nil)
orchDesc, err := params.node.OrchestratorPool.GetOrchestrators(ctx, 1, newSuspender(), caps, common.ScoreAtLeast(0))
modelID := defaultImageToVideoModelID
if req.ModelId != nil {
modelID = *req.ModelId
}

orchInfos, err := getOrchestratorsForAIRequest(ctx, params, core.Capability_ImageToVideo, modelID)
if err != nil {
return nil, err
}
orchInfos := orchDesc.GetRemoteInfos()

if len(orchInfos) == 0 {
return nil, errors.New("no orchestrators available")
}

orchUrl := orchInfos[0].Transcoder

var resp *worker.ImageResponse
op := func() error {

// Round robin up to maxProcessingRetries times
orchIdx := 0
tries := 0
for tries < maxProcessingRetries {
orchUrl := orchInfos[orchIdx].Transcoder

var err error
resp, err = submitImageToVideo(ctx, orchUrl, req)
return err
}
notify := func(err error, dur time.Duration) {
clog.Infof(ctx, "Error submitting ImageToVideo request err=%v retrying after dur=%v", err, dur)
if err == nil {
break
}

clog.Infof(ctx, "Error submitting ImageToVideo request try=%v orch=%v err=%v", tries, orchUrl, err)

tries++
orchIdx++
// Wrap back around
if orchIdx >= len(orchInfos) {
orchIdx = 0
}
}

b := backoff.WithMaxRetries(backoff.NewConstantBackOff(imageToVideoRetryBackoff), maxProcessingRetries)
if err := backoff.RetryNotify(op, b, notify); err != nil {
if resp == nil {
return nil, &ServiceUnavailableError{err: err}
}

Expand Down

0 comments on commit a436dc3

Please sign in to comment.