"""HTTP text-to-speech service backed by MMS-TTS (VITS) models.

One tokenizer/model pair is loaded from local disk for every member of
``LanguageModel`` (from ``config``) at import time. ``POST /synthesize`` turns
the request text into WAV audio using the pair selected by ``language``.
"""

import io

import fastapi
import fastapi.concurrency
import pydantic
import scipy
import scipy.io.wavfile  # `import scipy` alone does not guarantee this submodule is loaded
import torch
import transformers

from config import *

# Tokenizers and models keyed by ``LanguageModel`` member value; loaded once at
# import time and shared by all requests.
# NOTE(review): assumes every LanguageModel member name matches the suffix of a
# directory saved under ./data/{tokenizer,model}/facebook/mms-tts-* — confirm.
tokenizers = {}
models = {}
for lang in LanguageModel:
    tokenizers[lang.value] = transformers.AutoTokenizer.from_pretrained(
        f'./data/tokenizer/facebook/mms-tts-{lang.name}'
    )
    models[lang.value] = transformers.VitsModel.from_pretrained(
        f'./data/model/facebook/mms-tts-{lang.name}'
    )


class SynthesizeRequest(pydantic.BaseModel):
    """Request body for ``POST /synthesize``."""

    language: LanguageModel  # selects which tokenizer/model pair is used
    text: str  # text to be spoken


class SynthesizeResponse(fastapi.Response):
    """Raw HTTP response whose body is a WAV file."""

    media_type = 'audio/wav'


app = fastapi.FastAPI()


def _synthesize_wav(language: LanguageModel, text: str) -> bytes:
    """Run TTS for *text* with the model for *language*; return WAV-encoded bytes.

    This runs the model forward pass synchronously, so it must not be called
    directly on the event loop.
    """
    key = language.value
    inputs = tokenizers[key](text, return_tensors='pt')
    model = models[key]
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        waveform = model(**inputs).waveform
    with io.BytesIO() as fp:
        # Transposed so samples run along axis 0, as wavfile.write expects.
        # NOTE(review): waveform is presumably (batch=1, n_samples) here — verify.
        scipy.io.wavfile.write(
            fp,
            rate=model.config.sampling_rate,
            data=waveform.float().numpy().T,
        )
        # Read the buffer before the `with` closes it.
        return fp.getvalue()


@app.post('/synthesize', response_class=SynthesizeResponse)
async def synthesize(request: SynthesizeRequest) -> SynthesizeResponse:
    """Synthesize ``request.text`` in ``request.language`` and return WAV audio."""
    # Inference is blocking; running it inline in this coroutine would stall the
    # event loop (and every other in-flight request) for its whole duration.
    # NOTE(review): concurrent requests now run forward passes in parallel
    # worker threads on the shared models; consider capping concurrency if CPU
    # oversubscription becomes a problem.
    wav_bytes = await fastapi.concurrency.run_in_threadpool(
        _synthesize_wav, request.language, request.text
    )
    return SynthesizeResponse(content=wav_bytes)