<Update> pipeline_controlnet.py add latent split logic (#68)
* <Update> pipeline_controlnet.py add latent split logic
* translate Chinese annotations into English

---------

Co-authored-by: xzqjack <xzqjack@hotmail.com>
parent cc5ff21a59
commit f3aec3fbf7
@@ -1381,6 +1381,7 @@ class MusevControlNetPipeline(
         context_batch_size=1,
         interpolation_factor=1,
         # parallel_denoise parameter end
+        decoder_t_segment: int = 200,
     ):
         r"""
         A general-purpose pipeline intended to be compatible with text2video, text2image, img2img, video2video, and with or without controlnet. Currently only img2img and video2video are not supported.
@@ -2153,7 +2154,21 @@ class MusevControlNetPipeline(
             data2_index=latent_index,
             dim=2,
         )
-        video = self.decode_latents(latents)
+        b, c, t, h, w = latents.shape
+        num_segments = (t + decoder_t_segment - 1) // decoder_t_segment
+
+        video_segments = []
+        # To avoid the t channel being too large and causing GPU out-of-memory errors,
+        # split the video latents into slices along the t channel, decode each slice, and then concatenate the results.
+        for i in range(num_segments):
+            logger.debug(f"Decoding segment {i}")
+            start_t = i * decoder_t_segment
+            end_t = min((i + 1) * decoder_t_segment, t)
+            latents_segment = latents[:, :, start_t:end_t, :, :]
+            video_segment = self.decode_latents(latents_segment)
+            video_segments.append(video_segment)
+        video_segments_np = np.concatenate(video_segments, axis=2)
+        video = torch.from_numpy(video_segments_np)

         if skip_temporal_layer:
             self.unet.set_skip_temporal_layers(False)
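For reference, a minimal standalone sketch of the segmented decode introduced above, assuming `decode_latents` returns a NumPy array per chunk (which is what the `np.concatenate` / `torch.from_numpy` calls in the diff imply). The `fake_decode_latents` helper and `decode_in_segments` wrapper are hypothetical stand-ins, not part of the pipeline code:

```python
# Sketch of decoding video latents in slices along the t axis to cap peak GPU memory,
# assuming each chunk decodes to a NumPy array (as implied by the diff above).
import numpy as np
import torch


def fake_decode_latents(latents_segment: torch.Tensor) -> np.ndarray:
    # Hypothetical stand-in for MusevControlNetPipeline.decode_latents:
    # pretend each latent frame decodes to an RGB frame of the same spatial size.
    b, c, t, h, w = latents_segment.shape
    return np.zeros((b, 3, t, h, w), dtype=np.float32)


def decode_in_segments(latents: torch.Tensor, decoder_t_segment: int = 200) -> torch.Tensor:
    """Decode latents chunk by chunk along the t axis, then reassemble."""
    b, c, t, h, w = latents.shape
    num_segments = (t + decoder_t_segment - 1) // decoder_t_segment  # ceiling division
    video_segments = []
    for i in range(num_segments):
        start_t = i * decoder_t_segment
        end_t = min((i + 1) * decoder_t_segment, t)
        video_segments.append(fake_decode_latents(latents[:, :, start_t:end_t, :, :]))
    # Concatenate along the time axis and hand back a tensor, mirroring the pipeline code.
    return torch.from_numpy(np.concatenate(video_segments, axis=2))


if __name__ == "__main__":
    latents = torch.randn(1, 4, 450, 32, 32)
    video = decode_in_segments(latents, decoder_t_segment=200)  # 3 segments: 200 + 200 + 50
    print(video.shape)  # torch.Size([1, 3, 450, 32, 32])
```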