Some checks failed
Close inactive issues / close-issues (push) Has been cancelled
131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
import io
|
|
from hashlib import sha256
|
|
from pathlib import Path
|
|
from typing import Callable, Literal, Tuple
|
|
|
|
import torch
|
|
import torchaudio
|
|
from loguru import logger
|
|
|
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
|
from fish_speech.utils.file import (
|
|
AUDIO_EXTENSIONS,
|
|
audio_to_bytes,
|
|
list_files,
|
|
read_ref_text,
|
|
)
|
|
from fish_speech.utils.schema import ServeReferenceAudio
|
|
|
|
|
|
class ReferenceLoader:
    """
    Load reference audio/text pairs and cache their encoded form.

    Two independent caches are kept:

    * ``ref_by_id``   -- reference-folder id -> ``(prompt_tokens, prompt_texts)``
    * ``ref_by_hash`` -- SHA-256 of the raw audio -> ``(prompt_token, prompt_text)``
    """

    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.
        Loads and manages the cache for the reference audio and text.
        """
        self.ref_by_id: dict = {}
        self.ref_by_hash: dict = {}

        # Make Pylance happy (attribute/method not defined...).
        # These are declarations only; nothing is assigned here.
        self.decoder_model: FireflyArchitecture
        self.encode_reference: Callable

        # Define the torchaudio backend: prefer ffmpeg, else fall back to soundfile.
        backends = torchaudio.list_audio_backends()
        self.backend = "ffmpeg" if "ffmpeg" in backends else "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Load the reference audios and texts stored under ``references/<id>``.

        Args:
            id: Name of the sub-folder of ``references`` to read from.
            use_cache: ``"off"`` forces re-encoding (and refreshes the cache);
                otherwise a previously encoded entry for ``id`` is reused.

        Returns:
            ``(prompt_tokens, prompt_texts)`` -- two lists with one entry per
            audio file found: the result of ``self.encode_reference`` and the
            text read from the sibling ``.lab`` file.
        """
        ref_folder = Path("references") / id
        # The folder is created when missing, exactly as before, so listing it
        # below never fails on a non-existent path.
        ref_folder.mkdir(parents=True, exist_ok=True)

        if use_cache != "off" and id in self.ref_by_id:
            # Reuse already encoded references; no need to scan the folder.
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]
            return prompt_tokens, prompt_texts

        # The references are not loaded yet (or caching is off): encode them.
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
        prompt_tokens = [
            self.encode_reference(
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]
        self.ref_by_id[id] = (prompt_tokens, prompt_texts)

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: "list[ServeReferenceAudio]",
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Load the reference audios and texts, caching each one by audio hash.

        Args:
            references: Items exposing ``.audio`` (raw bytes) and ``.text``.
            use_cache: ``"off"`` forces re-encoding (and refreshes the cache);
                otherwise an already encoded reference with the same audio
                hash is reused.

        Returns:
            ``(prompt_tokens, prompt_texts)`` -- two lists aligned with
            ``references``, one entry per reference, in input order.
        """
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for audio_hash, ref in zip(audio_hashes, references):
            if use_cache == "off" or audio_hash not in self.ref_by_hash:
                # Not cached (or caching disabled): encode this reference.
                token = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                # Cache exactly ONE (token, text) pair per hash. Storing the
                # accumulating result lists here would alias them across
                # entries and corrupt later lookups.
                self.ref_by_hash[audio_hash] = (token, ref.text)
                prompt_tokens.append(token)
                prompt_texts.append(ref.text)
            else:
                # Reuse the already encoded reference and append it, so the
                # references processed earlier in this call are preserved.
                # NOTE: the cache is keyed on the audio only, so the text
                # stored alongside the first encoding is the one returned.
                cached_token, cached_text = self.ref_by_hash[audio_hash]
                prompt_tokens.append(cached_token)
                prompt_texts.append(cached_text)
                cache_used = True

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr):
        """
        Load the audio data from a file or bytes.

        Args:
            reference_audio: Either a path to an existing audio file or the
                raw encoded audio bytes.
            sr: Target sample rate; the audio is resampled when it differs.

        Returns:
            The waveform, mixed down to a single channel, as a NumPy array
            (``waveform.squeeze().numpy()``).
        """
        # Bytes-like input is always in-memory audio. Checking this first
        # keeps short byte payloads (<= 255 bytes) away from ``Path(...)``,
        # which rejects bytes with a TypeError.
        is_raw = isinstance(reference_audio, (bytes, bytearray, memoryview))
        if (
            is_raw
            or len(reference_audio) > 255
            or not Path(reference_audio).exists()
        ):
            reference_audio = io.BytesIO(reference_audio)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        # Average the channels when more than one is present.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio
|