<Update> pipeline_controlnet.py add latent split logic (#68)

* <Update> pipeline_controlnet.py add latent split logic

* translate Chinese annotations into English

---------

Co-authored-by: xzqjack <xzqjack@hotmail.com>
Oli_Zhan 2024-04-13 00:31:51 +08:00 committed by GitHub
parent cc5ff21a59
commit f3aec3fbf7


@@ -1381,6 +1381,7 @@ class MusevControlNetPipeline(
context_batch_size=1,
interpolation_factor=1,
# parallel_denoise parameter end
decoder_t_segment: int = 200,
):
r"""
A general pipeline intended to be compatible with text2video, text2image, img2img, and video2video, with or without controlnet; currently only img2img and video2video are not supported.
@@ -2153,7 +2154,21 @@ class MusevControlNetPipeline(
data2_index=latent_index,
dim=2,
)
video = self.decode_latents(latents)
b, c, t, h, w = latents.shape
num_segments = (t + decoder_t_segment - 1) // decoder_t_segment
video_segments = []
# to avoid the t dimension being too large and causing GPU out-of-memory errors,
# split the video latents into segments along the t dimension, decode each segment, then concatenate the results
for i in range(num_segments):
logger.debug(f"Decoding segment {i}/{num_segments}")
start_t = i * decoder_t_segment
end_t = min((i + 1) * decoder_t_segment, t)
latents_segment = latents[:, :, start_t:end_t, :, :]
video_segment = self.decode_latents(latents_segment)
video_segments.append(video_segment)
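# concatenate the decoded segments back along the time axis and convert to a torch tensor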
video_segments_np = np.concatenate(video_segments, axis=2)
video = torch.from_numpy(video_segments_np)
if skip_temporal_layer:
self.unet.set_skip_temporal_layers(False)
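For context, the same chunked-decoding idea as a standalone sketch, assuming a generic decode_fn that maps a (b, c, t, h, w) latent tensor to a NumPy array with the time axis in the same position; decode_in_segments and decode_fn are illustrative names, not part of the pipeline's API.

import numpy as np
import torch


def decode_in_segments(latents, decode_fn, segment_len=200):
    """Decode video latents chunk by chunk along the time axis to limit peak GPU memory.

    Assumes decode_fn maps a (b, c, t, h, w) latent tensor to a NumPy array
    whose time axis is also axis 2 (a stand-in here for decode_latents).
    """
    b, c, t, h, w = latents.shape
    num_segments = (t + segment_len - 1) // segment_len  # ceil(t / segment_len)
    decoded = []
    for i in range(num_segments):
        start, end = i * segment_len, min((i + 1) * segment_len, t)
        decoded.append(decode_fn(latents[:, :, start:end, :, :]))
    # stitch the decoded chunks back together along the time axis
    return torch.from_numpy(np.concatenate(decoded, axis=2))

With segment_len=200 and, for example, t=450, this decodes three slices of 200, 200, and 50 frames instead of the full latent tensor at once.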