commit | author | age
|
71c082
|
1 |
import io |
JK |
2 |
|
|
3 |
import fastapi |
|
4 |
import pydantic |
|
5 |
import scipy |
|
6 |
import transformers |
|
7 |
import torch |
|
8 |
|
|
9 |
from config import * |
|
10 |
|
|
11 |
# Per-language lookup tables, both keyed by the LanguageModel member's value.
tokenizers = {}
models = {}

# Load one tokenizer and one VITS model per language from local directories.
# The directory suffix is the enum member's *name*, while the dictionary key
# is the member's *value*.
for lang in LanguageModel:
    tokenizers[lang.value] = transformers.AutoTokenizer.from_pretrained(
        './data/tokenizer/facebook/mms-tts-' + lang.name
    )
    models[lang.value] = transformers.VitsModel.from_pretrained(
        './data/model/facebook/mms-tts-' + lang.name
    )
|
17 |
|
|
18 |
|
|
19 |
# Request body accepted by the /synthesize endpoint.
class SynthesizeRequest(pydantic.BaseModel):
    # Selects which loaded tokenizer/model pair handles the request.
    language: LanguageModel
    # Text to be converted to speech.
    text: str
|
22 |
|
|
23 |
class SynthesizeResponse(fastapi.Response):
    """Response subclass whose media type is fixed to WAV audio."""

    media_type = 'audio/wav'
|
25 |
|
|
26 |
|
|
27 |
# FastAPI application object; the route decorators below register on it.
app = fastapi.FastAPI()
|
28 |
|
|
29 |
|
|
30 |
@app.post('/synthesize', response_class=SynthesizeResponse)
async def synthesize(request: SynthesizeRequest) -> SynthesizeResponse:
    """Synthesize speech for the given text and return it as WAV audio."""
    # Imported explicitly: a bare `import scipy` is not guaranteed to expose
    # the `scipy.io.wavfile` submodule on every SciPy release.
    from scipy.io import wavfile

    # The tokenizer and model tables are both keyed by the enum member's value.
    language_key = request.language.value
    inputs = tokenizers[language_key](request.text, return_tensors='pt')
    model = models[language_key]

    # Inference only, so gradient tracking is switched off.
    # NOTE(review): this handler is `async` but never awaits, so the forward
    # pass runs synchronously on the event loop - confirm that is acceptable
    # under concurrent load.
    with torch.no_grad():
        output = model(**inputs).waveform

    # Encode into an in-memory buffer. getvalue() copies the bytes out while
    # the buffer is still open, so the response stays valid after the `with`.
    with io.BytesIO() as fp:
        wavfile.write(
            fp,
            rate=model.config.sampling_rate,
            # presumably waveform is (batch, samples), making the transpose
            # (samples, channels) - TODO confirm against the model docs.
            data=output.float().numpy().T,
        )
        return SynthesizeResponse(content=fp.getvalue())