2
0

voices_tab.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import json
  2. import gradio as gr
  3. import os
  4. import shutil
  5. from tts_webui.bark.history_to_hash import history_to_hash
  6. from tts_webui.history_tab.save_photo import save_photo
  7. from tts_webui.history_tab.edit_metadata_ui import edit_metadata_ui
  8. from tts_webui.bark.get_audio_from_npz import get_audio_from_full_generation
  9. from tts_webui.bark.npz_tools import load_npz, save_npz
  10. from tts_webui.history_tab.get_wav_files import get_npz_files_voices
  11. from tts_webui.history_tab.main import _get_filename, _get_row_index
  12. from tts_webui.history_tab.open_folder import open_folder
  13. from tts_webui.tortoise.gr_reload_button import gr_reload_button
  14. def update_voices_tab():
  15. return gr.List(value=get_npz_files_voices())
  16. def voices_tab(directory="voices"):
  17. with gr.Tab(directory.capitalize()) as voices_tab, gr.Row(equal_height=False):
  18. with gr.Column():
  19. with gr.Accordion("Gallery Selector (Click to Open)", open=False):
  20. history_list_as_gallery = gr.Gallery(
  21. value=[],
  22. columns=4,
  23. object_fit="contain",
  24. height="auto",
  25. )
  26. gr.Button(value="Refresh").click(
  27. fn=lambda: gr.Gallery(
  28. value=[
  29. f"voices/{x}"
  30. for x in os.listdir("voices")
  31. if x.endswith(".png")
  32. ]
  33. ),
  34. outputs=[history_list_as_gallery],
  35. )
  36. with gr.Row():
  37. button_output = gr.Button(value=f"Open {directory} folder")
  38. reload_button = gr_reload_button()
  39. button_output.click(lambda: open_folder(directory))
  40. datatypes = ["date", "str", "str", "str", "str"]
  41. headers = [
  42. "Date and Time",
  43. directory.capitalize(),
  44. "When",
  45. "Hash",
  46. "Filename",
  47. ]
  48. voices_list = gr.Dataframe(
  49. value=get_npz_files_voices(),
  50. interactive=False,
  51. datatype=datatypes,
  52. col_count=len(datatypes),
  53. headers=headers,
  54. max_height=800,
  55. # elem_classes="file-list"
  56. )
  57. with gr.Column():
  58. audio = gr.Audio(visible=True, type="numpy", label="Fine prompt audio")
  59. voice_hash = gr.Textbox(label="Hash", value="", interactive=False)
  60. crop_voice_button = gr.Button(value="Crop voice")
  61. voice_file_name = gr.Textbox(
  62. label="Voice file name", value="", interactive=False
  63. )
  64. new_voice_file_name = gr.Textbox(label="New voice file name", value="")
  65. with gr.Row():
  66. rename_voice_button = gr.Button(value="Rename voice")
  67. delete_voice_button = gr.Button(value="Delete voice", variant="stop")
  68. gr.Markdown("""Use voice button is now only available in React UI""")
  69. metadata = gr.JSON(label="Metadata")
  70. metadata_input = edit_metadata_ui(voice_file_name, metadata)
  71. photo = gr.Image(label="Photo", type="pil", interactive=True)
  72. file_list = gr.Files(value=[], label="Files", interactive=False)
  73. photo.upload(
  74. fn=save_photo,
  75. inputs=[photo, voice_file_name],
  76. outputs=[photo],
  77. )
  78. def delete_voice(voice_file_name):
  79. os.remove(voice_file_name)
  80. return {
  81. delete_voice_button: gr.Button(value="Deleted"),
  82. voices_list: update_voices_tab(),
  83. }
  84. def rename_voice(voice_file_name_in, new_voice_file_name):
  85. shutil.move(voice_file_name_in, new_voice_file_name)
  86. png_file = voice_file_name_in.replace(".npz", ".png")
  87. if os.path.exists(png_file):
  88. shutil.move(png_file, new_voice_file_name.replace(".npz", ".png"))
  89. return {
  90. rename_voice_button: gr.Button(value="Renamed"),
  91. voices_list: update_voices_tab(),
  92. voice_file_name: gr.Textbox(value=new_voice_file_name),
  93. }
  94. def crop_voice(voice_file_name, audio_in):
  95. from bark.generation import COARSE_RATE_HZ, SEMANTIC_RATE_HZ, N_COARSE_CODEBOOKS
  96. crop_min, crop_max = audio_in.get("crop_min", 0), audio_in.get("crop_max", 100)
  97. full_generation = load_npz(voice_file_name)
  98. semantic_prompt = full_generation["semantic_prompt"]
  99. len_semantic_prompt = len(semantic_prompt)
  100. semantic_prompt = semantic_prompt[
  101. len_semantic_prompt * crop_min // 100 : len_semantic_prompt
  102. * crop_max
  103. // 100
  104. ]
  105. coarse_prompt = full_generation["coarse_prompt"]
  106. len_coarse_prompt = coarse_prompt.shape[-1]
  107. coarse_prompt = coarse_prompt[
  108. :, len_coarse_prompt * crop_min // 100 : len_coarse_prompt * crop_max // 100
  109. ]
  110. fine_prompt = full_generation["fine_prompt"]
  111. len_fine_prompt = fine_prompt.shape[-1]
  112. fine_prompt = fine_prompt[
  113. :, len_fine_prompt * crop_min // 100 : len_fine_prompt * crop_max // 100
  114. ]
  115. semantic_to_coarse_ratio = (
  116. COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
  117. )
  118. assert round(coarse_prompt.shape[-1] / len(semantic_prompt), 1) == round(
  119. semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1
  120. )
  121. voice_file_name_cropped = voice_file_name.replace(
  122. ".npz", f"_cropped_{crop_min}_{crop_max}.npz"
  123. )
  124. new_hash = history_to_hash(full_generation) # type: ignore
  125. new_meta = full_generation.get("metadata", {})
  126. new_meta["crop_min"] = crop_min
  127. new_meta["crop_max"] = crop_max
  128. new_meta["hash"] = new_hash
  129. save_npz(
  130. voice_file_name_cropped,
  131. {
  132. "semantic_prompt": semantic_prompt,
  133. "coarse_prompt": coarse_prompt,
  134. "fine_prompt": fine_prompt,
  135. },
  136. metadata=new_meta,
  137. )
  138. return select_filename(voice_file_name_cropped)
  139. rename_voice_button.click(
  140. fn=rename_voice,
  141. inputs=[voice_file_name, new_voice_file_name],
  142. outputs=[rename_voice_button, voices_list, voice_file_name],
  143. )
  144. delete_voice_button.click(
  145. fn=delete_voice,
  146. inputs=[voice_file_name],
  147. outputs=[delete_voice_button, voices_list],
  148. )
  149. def select_filename(filename_npz):
  150. full_generation = load_npz(filename_npz)
  151. resolved_photo = filename_npz.replace(".npz", ".png")
  152. if not os.path.exists(resolved_photo):
  153. resolved_photo = None
  154. return {
  155. voice_file_name: gr.Textbox(value=filename_npz),
  156. new_voice_file_name: gr.Textbox(value=filename_npz),
  157. delete_voice_button: gr.Button(value="Delete"),
  158. rename_voice_button: gr.Button(value="Rename"),
  159. audio: gr.Audio(value=get_audio_from_full_generation(full_generation)), # type: ignore
  160. metadata: gr.JSON(value=full_generation.get("metadata", {})),
  161. metadata_input: gr.Textbox(
  162. value=json.dumps(full_generation.get("metadata", {}), indent=2)
  163. ),
  164. photo: gr.Image(value=resolved_photo),
  165. voice_hash: gr.Textbox(value=history_to_hash(full_generation)), # type: ignore
  166. file_list: gr.Files(
  167. value=get_file_list(filename_npz, resolved_photo),
  168. label="Files",
  169. ),
  170. }
  171. def get_file_list(filename_npz, resolved_photo):
  172. if resolved_photo is None:
  173. return [filename_npz]
  174. return [filename_npz, resolved_photo]
  175. def select(_list_data, evt: gr.SelectData):
  176. filename_npz = _get_filename(_list_data, _get_row_index(evt))
  177. return select_filename(filename_npz)
  178. outputs = [
  179. voice_file_name,
  180. new_voice_file_name,
  181. delete_voice_button,
  182. rename_voice_button,
  183. audio,
  184. metadata,
  185. metadata_input,
  186. photo,
  187. voice_hash,
  188. file_list,
  189. ]
  190. crop_voice_button.click(
  191. fn=crop_voice,
  192. inputs=[voice_file_name, audio],
  193. outputs=outputs,
  194. preprocess=False,
  195. ).then(
  196. fn=update_voices_tab,
  197. outputs=[voices_list],
  198. )
  199. reload_button.click(fn=update_voices_tab, outputs=[voices_list])
  200. voices_list.select(
  201. fn=select, inputs=[voices_list], outputs=outputs, preprocess=False
  202. )
  203. def select_gallery(_list_data, evt: gr.SelectData):
  204. def get_gallery_file_selection(_gallery_data, evt: gr.SelectData):
  205. selected_image = _gallery_data[evt.index]
  206. image_path = selected_image["name"]
  207. import os
  208. image_name = os.path.basename(image_path)
  209. return image_name.replace(".png", "")
  210. filename_base = get_gallery_file_selection(_list_data, evt)
  211. return select_filename(f"voices/{filename_base}.npz")
  212. history_list_as_gallery.select(
  213. fn=select_gallery, inputs=[history_list_as_gallery], outputs=outputs
  214. )
  215. voices_tab.select(fn=update_voices_tab, outputs=[voices_list])