Besides being invoked by the button, explicitly call limit_shape on img_edge_ratio in Gradio

This commit is contained in:
xzqjack 2024-04-11 11:29:08 +08:00
parent 8a81fa6d4e
commit ddba6a4172
3 changed files with 62 additions and 18 deletions

View File

@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
CheckpointsDir = os.path.join(ProjectDir, "checkpoints") CheckpointsDir = os.path.join(ProjectDir, "checkpoints")
ignore_video2video = False ignore_video2video = False
max_image_edge = 1280
def download_model(): def download_model():
@ -46,6 +47,9 @@ def hf_online_t2v_inference(
video_len, video_len,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_t2v_inference( return online_t2v_inference(
@ -66,6 +70,9 @@ def hg_online_v2v_inference(
video_length, video_length,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_v2v_inference( return online_v2v_inference(
@ -82,11 +89,17 @@ def hg_online_v2v_inference(
) )
def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=960): def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge):
"""limite generation video shape to avoid gpu memory overflow""" """limite generation video shape to avoid gpu memory overflow"""
if isinstance(image, np.ndarray) and (input_h == -1 and input_w == -1): if input_h == -1 and input_w == -1:
input_h, input_w, _ = image.shape if isinstance(image, np.ndarray):
h, w, _ = image.shape input_h, input_w, _ = image.shape
elif isinstance(image, PIL.Image.Image):
input_w, input_h = image.size
else:
raise ValueError(
f"image should be in [image, ndarray], but given {type(image)}"
)
if img_edge_ratio == 0: if img_edge_ratio == 0:
img_edge_ratio = 1 img_edge_ratio = 1
img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio)
@ -235,8 +248,8 @@ with gr.Blocks(css=css) as demo:
"../../data/images/yongen.jpeg", "../../data/images/yongen.jpeg",
], ],
[ [
"(masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)", "(masterpiece, best quality, highres:1), peaceful beautiful sea scene",
"../../data/images/The-Laughing-Cavalier.jpg", "../../data/images/seaside4.jpeg",
], ],
] ]
with gr.Row(): with gr.Row():
@ -345,6 +358,7 @@ with gr.Blocks(css=css) as demo:
fn=hg_online_v2v_inference, fn=hg_online_v2v_inference,
cache_examples=False, cache_examples=False,
) )
img_edge_ratio.change( img_edge_ratio.change(
fn=limit_shape, fn=limit_shape,
inputs=[image, w, h, img_edge_ratio], inputs=[image, w, h, img_edge_ratio],

View File

@ -2,16 +2,19 @@ import os
import time import time
import pdb import pdb
import PIL.Image
import cuid import cuid
import gradio as gr import gradio as gr
import spaces import spaces
import numpy as np import numpy as np
import PIL
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
CheckpointsDir = os.path.join(ProjectDir, "checkpoints") CheckpointsDir = os.path.join(ProjectDir, "checkpoints")
ignore_video2video = True ignore_video2video = True
max_image_edge = 960
def download_model(): def download_model():
@ -46,6 +49,9 @@ def hf_online_t2v_inference(
video_len, video_len,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_t2v_inference( return online_t2v_inference(
@ -66,6 +72,9 @@ def hg_online_v2v_inference(
video_length, video_length,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_v2v_inference( return online_v2v_inference(
@ -82,11 +91,17 @@ def hg_online_v2v_inference(
) )
def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=960): def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge):
"""limite generation video shape to avoid gpu memory overflow""" """limite generation video shape to avoid gpu memory overflow"""
if isinstance(image, np.ndarray) and (input_h == -1 and input_w == -1): if input_h == -1 and input_w == -1:
input_h, input_w, _ = image.shape if isinstance(image, np.ndarray):
h, w, _ = image.shape input_h, input_w, _ = image.shape
elif isinstance(image, PIL.Image.Image):
input_w, input_h = image.size
else:
raise ValueError(
f"image should be in [image, ndarray], but given {type(image)}"
)
if img_edge_ratio == 0: if img_edge_ratio == 0:
img_edge_ratio = 1 img_edge_ratio = 1
img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio)
@ -235,8 +250,8 @@ with gr.Blocks(css=css) as demo:
"../../data/images/yongen.jpeg", "../../data/images/yongen.jpeg",
], ],
[ [
"(masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)", "(masterpiece, best quality, highres:1), peaceful beautiful sea scene",
"../../data/images/The-Laughing-Cavalier.jpg", "../../data/images/seaside4.jpeg",
], ],
] ]
with gr.Row(): with gr.Row():
@ -345,6 +360,7 @@ with gr.Blocks(css=css) as demo:
fn=hg_online_v2v_inference, fn=hg_online_v2v_inference,
cache_examples=False, cache_examples=False,
) )
img_edge_ratio.change( img_edge_ratio.change(
fn=limit_shape, fn=limit_shape,
inputs=[image, w, h, img_edge_ratio], inputs=[image, w, h, img_edge_ratio],

View File

@ -46,6 +46,7 @@ result = subprocess.run(
) )
print(result) print(result)
ignore_video2video = True ignore_video2video = True
max_image_edge = 960
def download_model(): def download_model():
@ -81,6 +82,9 @@ def hf_online_t2v_inference(
video_len, video_len,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_t2v_inference( return online_t2v_inference(
@ -101,6 +105,9 @@ def hg_online_v2v_inference(
video_length, video_length,
img_edge_ratio, img_edge_ratio,
): ):
img_edge_ratio, _, _ = limit_shape(
image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge
)
if not isinstance(image_np, np.ndarray): # None if not isinstance(image_np, np.ndarray): # None
raise gr.Error("Need input reference image") raise gr.Error("Need input reference image")
return online_v2v_inference( return online_v2v_inference(
@ -117,11 +124,17 @@ def hg_online_v2v_inference(
) )
def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=960): def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge):
"""limite generation video shape to avoid gpu memory overflow""" """limite generation video shape to avoid gpu memory overflow"""
if isinstance(image, np.ndarray) and (input_h == -1 and input_w == -1): if input_h == -1 and input_w == -1:
input_h, input_w, _ = image.shape if isinstance(image, np.ndarray):
h, w, _ = image.shape input_h, input_w, _ = image.shape
elif isinstance(image, PIL.Image.Image):
input_w, input_h = image.size
else:
raise ValueError(
f"image should be in [image, ndarray], but given {type(image)}"
)
if img_edge_ratio == 0: if img_edge_ratio == 0:
img_edge_ratio = 1 img_edge_ratio = 1
img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio)
@ -270,8 +283,8 @@ with gr.Blocks(css=css) as demo:
"../../data/images/yongen.jpeg", "../../data/images/yongen.jpeg",
], ],
[ [
"(masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3)", "(masterpiece, best quality, highres:1), peaceful beautiful sea scene",
"../../data/images/The-Laughing-Cavalier.jpg", "../../data/images/seaside4.jpeg",
], ],
] ]
with gr.Row(): with gr.Row():
@ -380,6 +393,7 @@ with gr.Blocks(css=css) as demo:
fn=hg_online_v2v_inference, fn=hg_online_v2v_inference,
cache_examples=False, cache_examples=False,
) )
img_edge_ratio.change( img_edge_ratio.change(
fn=limit_shape, fn=limit_shape,
inputs=[image, w, h, img_edge_ratio], inputs=[image, w, h, img_edge_ratio],