import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.
        """

        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []

        while True:
            # Get the response from the LLAMA model
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

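    # Illustrative usage sketch: assuming `engine` is a constructed
    # TTSInferenceEngine and `req` is a ServeTTSRequest, a non-streaming caller
    # might drain the generator like this:
    #
    #   for result in engine.inference(req):
    #       if result.code == "error":
    #           raise result.error
    #       if result.code == "final":
    #           sample_rate, audio = result.audio  # audio is a float numpy array
    #
    # Streaming callers additionally receive the "header" and "segment" results
    # yielded above before the "final" one.
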
    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.
        """

        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue to get the response
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

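    # Queue contract assumed by the methods above: the worker draining
    # `llama_queue` is expected to put WrappedGenerateResponse objects onto the
    # returned `response_queue`. `inference` decodes every response whose
    # `action` is not "next" and treats a "next" response as the end of
    # generation for the request.
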
    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """

        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
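
# Illustrative construction sketch (the queue and decoder loaders live
# elsewhere in the project and are assumptions here):
#
#   llama_queue = ...       # queue.Queue serviced by a text2semantic worker
#   decoder_model = ...     # a loaded FireflyArchitecture instance
#   engine = TTSInferenceEngine(
#       llama_queue=llama_queue,
#       decoder_model=decoder_model,
#       precision=torch.bfloat16,   # example dtype
#       compile=False,
#   )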