Some checks failed
Close inactive issues / close-issues (push) Has been cancelled
77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from argparse import ArgumentParser
|
|
from http import HTTPStatus
|
|
from typing import Annotated, Any
|
|
|
|
import ormsgpack
|
|
from baize.datastructures import ContentType
|
|
from kui.asgi import HTTPException, HttpRequest
|
|
|
|
from fish_speech.inference_engine import TTSInferenceEngine
|
|
from fish_speech.utils.schema import ServeTTSRequest
|
|
from tools.server.inference import inference_wrapper as inference
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) the process command line is used, which is what a
            bare ``parse_args()`` call always did.

    Returns:
        The ``argparse.Namespace`` holding ``mode``, ``load_asr_model``,
        ``llama_checkpoint_path``, ``decoder_checkpoint_path``,
        ``decoder_config_name``, ``device``, ``half``, ``compile``,
        ``max_text_length``, ``listen``, ``workers`` and ``api_key``.
    """
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--api-key", type=str, default=None)

    # Forwarding None makes argparse fall back to sys.argv[1:], so existing
    # zero-argument callers behave exactly as before.
    return parser.parse_args(argv)
|
|
|
|
|
|
class MsgPackRequest(HttpRequest):
    """Request class whose ``data`` decodes msgpack or JSON bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its content type.

        Returns:
            The result of ``ormsgpack.unpackb`` on the body for
            ``application/msgpack`` requests, or the awaited ``self.json``
            for ``application/json`` requests.

        Raises:
            HTTPException: With ``UNSUPPORTED_MEDIA_TYPE`` and an ``Accept``
                header naming both supported types, for any other content
                type.
        """
        media_type = self.content_type

        if media_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if media_type == "application/json":
            return await self.json

        # Neither supported type matched: reject and advertise what is accepted.
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )
|
|
|
|
|
|
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Yield the ``bytes`` items produced by ``inference(req, engine)``.

    Items that are not ``bytes`` instances are skipped.
    """
    # NOTE(review): this is a plain ``for`` over ``inference``, not
    # ``async for`` — presumably each step runs inline in the calling task;
    # confirm against ``tools.server.inference``.
    for piece in inference(req, engine):
        if not isinstance(piece, bytes):
            continue
        yield piece
|
|
|
|
|
|
async def buffer_to_async_generator(buffer):
    """Wrap ``buffer`` in an async generator that yields it exactly once."""
    yield buffer
|
|
|
|
|
|
def get_content_type(audio_format):
    """Return the MIME type string for an audio format name.

    Args:
        audio_format: Format name; ``"wav"``, ``"flac"`` and ``"mp3"`` are
            recognised.

    Returns:
        ``"audio/wav"``, ``"audio/flac"`` or ``"audio/mpeg"`` for the
        recognised names, and ``"application/octet-stream"`` for anything
        else.
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Compare with == rather than hashing so any input value falls through
    # to the generic fallback instead of raising.
    for name, mime in known_types:
        if audio_format == name:
            return mime
    return "application/octet-stream"
|