Skip to content

生成字幕 #1658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions GPT_SoVITS/TTS_infer_pack/TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def to_batch(self, data:list,
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
origin_text_batch = []
all_bert_max_len = 0
all_phones_max_len = 0
for item in item_list:
Expand All @@ -575,6 +576,7 @@ def to_batch(self, data:list,
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
origin_text_batch.append(item["origin_text"])

phones_batch = phones_list
all_phones_batch = all_phones_list
Expand Down Expand Up @@ -606,6 +608,7 @@ def to_batch(self, data:list,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch,
"origin_text": origin_text_batch,
"max_len": max_len,
}
_data.append(batch)
Expand Down Expand Up @@ -658,6 +661,7 @@ def run(self, inputs:dict):
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"return_with_srt": "", # str. return with or without("") subtitles, using "orig"inal or "norm"alized text
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
Expand Down Expand Up @@ -685,6 +689,7 @@ def run(self, inputs:dict):
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
return_with_srt = inputs.get("return_with_srt", "")
seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed
actual_seed = set_seed(seed)
Expand All @@ -704,6 +709,9 @@ def run(self, inputs:dict):
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))

ret_width = 3 if return_with_srt else 2 # return (sr, audio, srt) or (sr, audio)
srt_text = "norm_text" if return_with_srt.startswith("norm") else "origin_text"

if split_bucket and speed_factor==1.0:
print(i18n("分桶处理模式已开启"))
elif speed_factor!=1.0:
Expand Down Expand Up @@ -773,8 +781,7 @@ def run(self, inputs:dict):
if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
yield self.audio_failure()[:ret_width]
return

batch_index_list:list = None
Expand Down Expand Up @@ -806,6 +813,7 @@ def make_batch(batch_texts):
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
"origin_text": text,
}
batch_data.append(res)
if len(batch_data) == 0:
Expand Down Expand Up @@ -841,10 +849,11 @@ def make_batch(batch_texts):
all_phoneme_ids:torch.LongTensor = item["all_phones"]
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
all_bert_features:torch.LongTensor = item["all_bert_features"]
norm_text:str = item["norm_text"]
# norm_text:List[str] = item["norm_text"]
# origin_text:List[str] = item["origin_text"]
max_len = item["max_len"]

print(i18n("前端处理后的文本(每句):"), norm_text)
print(i18n("前端处理后的文本(每批):"), item["norm_text"])
if no_prompt_text :
prompt = None
else:
Expand Down Expand Up @@ -915,39 +924,38 @@ def make_batch(batch_texts):
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess([batch_audio_fragment],
[item[srt_text]],
self.configs.sampling_rate,
None,
speed_factor,
False,
fragment_interval
)
)[:ret_width]
else:
audio.append(batch_audio_fragment)

if self.stop_flag:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
yield self.audio_failure()[:ret_width]
return

if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
yield self.audio_failure()[:ret_width]
return
yield self.audio_postprocess(audio,
[v[srt_text] for v in data],
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket,
fragment_interval
)
)[:ret_width]

except Exception as e:
traceback.print_exc()
# 必须返回一个空音频, 否则会导致显存不释放。
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
yield self.audio_failure()[:ret_width]
# 重置模型, 否则会导致显存释放不完全。
del self.t2s_model
del self.vits_model
Expand All @@ -968,15 +976,19 @@ def empty_cache(self):
torch.mps.empty_cache()
except:
pass


def audio_failure(self):
return self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16), []

def audio_postprocess(self,
audio:List[torch.Tensor],
audio:List[torch.Tensor],
texts:List[List[str]],
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True,
fragment_interval:float=0.3
)->Tuple[int, np.ndarray]:
)->Tuple[int, np.ndarray, List]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval),
dtype=self.precision,
Expand All @@ -993,11 +1005,17 @@ def audio_postprocess(self,

if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
texts = self.recovery_order(texts, batch_index_list)
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])


texts = sum(texts, [])

# 按顺序计算每段语音的起止时间,并与文字一一对应,用于生成字幕
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后处理计算音频时间和恢复顺序这边,不需要返回字幕的话不去计算应该好一点,就是用单独的逻辑去控制是否需要计算。

from itertools import accumulate
stamps = [0.0] + [x/sr for x in accumulate([v.size for v in audio])]
srts = list(zip(stamps[:-1], stamps[1:], texts)) # time start, end, text

audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)

Expand All @@ -1007,7 +1025,7 @@ def audio_postprocess(self,
# except Exception as e:
# print(f"Failed to change speed of audio: \n{e}")

return sr, audio
return sr, audio, srts



Expand Down
1 change: 1 addition & 0 deletions GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2"
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
"origin_text": text,
}
result.append(res)
return result
Expand Down
60 changes: 56 additions & 4 deletions api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
Expand Down Expand Up @@ -98,7 +99,7 @@
import os
import sys
import traceback
from typing import Generator
from typing import Generator, List, Union

now_dir = os.getcwd()
sys.path.append(now_dir)
Expand Down Expand Up @@ -162,6 +163,7 @@ class TTS_Request(BaseModel):
seed:int = -1
media_type:str = "wav"
streaming_mode:bool = False
with_srt_format:str = ""
parallel_infer:bool = True
repetition_penalty:float = 1.35

Expand Down Expand Up @@ -211,7 +213,38 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
io_buffer.seek(0)
return io_buffer


def pack_srt(srt:List, fmt:str):
if fmt == "raw":
return srt
# TODO: support formats like "srt", "lrc", "vtt", ...
return srt

def load_base64_audio(audio):
import base64
if isinstance(audio, (bytes, bytearray)):
audio = bytes(audio)
elif hasattr(audio, 'read'): # file-like obj
audio = audio.read()
else: # path-like
audio = open(audio, 'rb').read()
return base64.b64encode(audio).decode('ascii')

_base64_audio_cache = {}
def save_base64_audio(b64str:str):
import filetype, base64, uuid
global _base64_audio_cache
if b64str in _base64_audio_cache:
return _base64_audio_cache[b64str]
savedir = 'TEMP/upload'
data = base64.b64decode(b64str)
ft = filetype.guess(data)
ext = f'.{ft.extension}' if ft else ''
os.makedirs(savedir, exist_ok=True)
saveto = f'{savedir}/{uuid.uuid1()}{ext}'
with open(saveto, 'wb') as outf:
outf.write(data)
_base64_audio_cache[b64str] = saveto
return saveto

# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
Expand Down Expand Up @@ -277,7 +310,7 @@ async def tts_handle(req:dict):
{
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"ref_audio_path": "", # str.(required) reference audio path ; allow data of format base64:xxxxxx
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
Expand All @@ -293,6 +326,7 @@ async def tts_handle(req:dict):
"seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response.
"with_srt_format": "", # str. ""(no srt) or "raw" or "srt", "lrc", "vtt", ... formats (not implemented yet)
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
}
Expand All @@ -303,14 +337,21 @@ async def tts_handle(req:dict):
streaming_mode = req.get("streaming_mode", False)
return_fragment = req.get("return_fragment", False)
media_type = req.get("media_type", "wav")
with_srt_format = req.get("with_srt_format", "")
ref_audio_path = req.get("ref_audio_path", "")
if ref_audio_path.startswith("base64:"):
req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])

check_res = check_params(req)
if check_res is not None:
return check_res

if streaming_mode or return_fragment:
req["return_fragment"] = True


if streaming_mode: with_srt_format = "" # streaming not support srt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

流式不支持字幕时最好log输出一下,提醒用户。

req["return_with_srt"] = "orig" if with_srt_format else ""

try:
tts_generator=tts_pipeline.run(req)

Expand All @@ -324,6 +365,16 @@ def streaming_generator(tts_generator:Generator, media_type:str):
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")

elif with_srt_format:
output = []
for sr, audio_data, srt_data in tts_generator:
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
output.append({
"audio": load_base64_audio(audio_data), "media_type": f"audio/{media_type}",
"srt": pack_srt(srt_data, with_srt_format), "srt_fmt": with_srt_format,
})
return { "message":"succeed", "output":output } # Jsonresponse(status_code=200, content=...)

else:
sr, audio_data = next(tts_generator)
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
Expand Down Expand Up @@ -364,6 +415,7 @@ async def tts_get_endpoint(
seed:int = -1,
media_type:str = "wav",
streaming_mode:bool = False,
with_srt_format:str = "",
parallel_infer:bool = True,
repetition_penalty:float = 1.35
):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ opencc; sys_platform != 'linux'
opencc==1.1.1; sys_platform == 'linux'
python_mecab_ko; sys_platform != 'win32'
fastapi<0.112.2
filetype