# slicer2.py
  1. import numpy as np
  2. # This function is obtained from librosa.
  3. def get_rms(
  4. y,
  5. frame_length=2048,
  6. hop_length=512,
  7. pad_mode="constant",
  8. ):
  9. padding = (int(frame_length // 2), int(frame_length // 2))
  10. y = np.pad(y, padding, mode=pad_mode)
  11. axis = -1
  12. # put our new within-frame axis at the end for now
  13. out_strides = y.strides + tuple([y.strides[axis]])
  14. # Reduce the shape on the framing axis
  15. x_shape_trimmed = list(y.shape)
  16. x_shape_trimmed[axis] -= frame_length - 1
  17. out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
  18. xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
  19. if axis < 0:
  20. target_axis = axis - 1
  21. else:
  22. target_axis = axis + 1
  23. xw = np.moveaxis(xw, -1, target_axis)
  24. # Downsample along the target axis
  25. slices = [slice(None)] * xw.ndim
  26. slices[axis] = slice(0, None, hop_length)
  27. x = xw[tuple(slices)]
  28. # Calculate power
  29. power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
  30. return np.sqrt(power)
  31. class Slicer:
  32. def __init__(
  33. self,
  34. sr: int,
  35. threshold: float = -40.0,
  36. min_length: int = 5000,
  37. min_interval: int = 300,
  38. hop_size: int = 20,
  39. max_sil_kept: int = 5000,
  40. ):
  41. if not min_length >= min_interval >= hop_size:
  42. raise ValueError(
  43. "The following condition must be satisfied: min_length >= min_interval >= hop_size"
  44. )
  45. if not max_sil_kept >= hop_size:
  46. raise ValueError(
  47. "The following condition must be satisfied: max_sil_kept >= hop_size"
  48. )
  49. min_interval = sr * min_interval / 1000
  50. self.threshold = 10 ** (threshold / 20.0)
  51. self.hop_size = round(sr * hop_size / 1000)
  52. self.win_size = min(round(min_interval), 4 * self.hop_size)
  53. self.min_length = round(sr * min_length / 1000 / self.hop_size)
  54. self.min_interval = round(min_interval / self.hop_size)
  55. self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
  56. def _apply_slice(self, waveform, begin, end):
  57. if len(waveform.shape) > 1:
  58. return waveform[
  59. :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
  60. ]
  61. else:
  62. return waveform[
  63. begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
  64. ]
  65. # @timeit
  66. def slice(self, waveform):
  67. if len(waveform.shape) > 1:
  68. samples = waveform.mean(axis=0)
  69. else:
  70. samples = waveform
  71. if samples.shape[0] <= self.min_length:
  72. return [waveform]
  73. rms_list = get_rms(
  74. y=samples, frame_length=self.win_size, hop_length=self.hop_size
  75. ).squeeze(0)
  76. sil_tags = []
  77. silence_start = None
  78. clip_start = 0
  79. for i, rms in enumerate(rms_list):
  80. # Keep looping while frame is silent.
  81. if rms < self.threshold:
  82. # Record start of silent frames.
  83. if silence_start is None:
  84. silence_start = i
  85. continue
  86. # Keep looping while frame is not silent and silence start has not been recorded.
  87. if silence_start is None:
  88. continue
  89. # Clear recorded silence start if interval is not enough or clip is too short
  90. is_leading_silence = silence_start == 0 and i > self.max_sil_kept
  91. need_slice_middle = (
  92. i - silence_start >= self.min_interval
  93. and i - clip_start >= self.min_length
  94. )
  95. if not is_leading_silence and not need_slice_middle:
  96. silence_start = None
  97. continue
  98. # Need slicing. Record the range of silent frames to be removed.
  99. if i - silence_start <= self.max_sil_kept:
  100. pos = rms_list[silence_start : i + 1].argmin() + silence_start
  101. if silence_start == 0:
  102. sil_tags.append((0, pos))
  103. else:
  104. sil_tags.append((pos, pos))
  105. clip_start = pos
  106. elif i - silence_start <= self.max_sil_kept * 2:
  107. pos = rms_list[
  108. i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
  109. ].argmin()
  110. pos += i - self.max_sil_kept
  111. pos_l = (
  112. rms_list[
  113. silence_start : silence_start + self.max_sil_kept + 1
  114. ].argmin()
  115. + silence_start
  116. )
  117. pos_r = (
  118. rms_list[i - self.max_sil_kept : i + 1].argmin()
  119. + i
  120. - self.max_sil_kept
  121. )
  122. if silence_start == 0:
  123. sil_tags.append((0, pos_r))
  124. clip_start = pos_r
  125. else:
  126. sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
  127. clip_start = max(pos_r, pos)
  128. else:
  129. pos_l = (
  130. rms_list[
  131. silence_start : silence_start + self.max_sil_kept + 1
  132. ].argmin()
  133. + silence_start
  134. )
  135. pos_r = (
  136. rms_list[i - self.max_sil_kept : i + 1].argmin()
  137. + i
  138. - self.max_sil_kept
  139. )
  140. if silence_start == 0:
  141. sil_tags.append((0, pos_r))
  142. else:
  143. sil_tags.append((pos_l, pos_r))
  144. clip_start = pos_r
  145. silence_start = None
  146. # Deal with trailing silence.
  147. total_frames = rms_list.shape[0]
  148. if (
  149. silence_start is not None
  150. and total_frames - silence_start >= self.min_interval
  151. ):
  152. silence_end = min(total_frames, silence_start + self.max_sil_kept)
  153. pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
  154. sil_tags.append((pos, total_frames + 1))
  155. # Apply and return slices.
  156. ####音频+起始时间+终止时间
  157. if len(sil_tags) == 0:
  158. return [[waveform,0,int(total_frames*self.hop_size)]]
  159. else:
  160. chunks = []
  161. if sil_tags[0][0] > 0:
  162. chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
  163. for i in range(len(sil_tags) - 1):
  164. chunks.append(
  165. [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
  166. )
  167. if sil_tags[-1][1] < total_frames:
  168. chunks.append(
  169. [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
  170. )
  171. return chunks
  172. def main():
  173. import os.path
  174. from argparse import ArgumentParser
  175. import librosa
  176. import soundfile
  177. parser = ArgumentParser()
  178. parser.add_argument("audio", type=str, help="The audio to be sliced")
  179. parser.add_argument(
  180. "--out", type=str, help="Output directory of the sliced audio clips"
  181. )
  182. parser.add_argument(
  183. "--db_thresh",
  184. type=float,
  185. required=False,
  186. default=-40,
  187. help="The dB threshold for silence detection",
  188. )
  189. parser.add_argument(
  190. "--min_length",
  191. type=int,
  192. required=False,
  193. default=5000,
  194. help="The minimum milliseconds required for each sliced audio clip",
  195. )
  196. parser.add_argument(
  197. "--min_interval",
  198. type=int,
  199. required=False,
  200. default=300,
  201. help="The minimum milliseconds for a silence part to be sliced",
  202. )
  203. parser.add_argument(
  204. "--hop_size",
  205. type=int,
  206. required=False,
  207. default=10,
  208. help="Frame length in milliseconds",
  209. )
  210. parser.add_argument(
  211. "--max_sil_kept",
  212. type=int,
  213. required=False,
  214. default=500,
  215. help="The maximum silence length kept around the sliced clip, presented in milliseconds",
  216. )
  217. args = parser.parse_args()
  218. out = args.out
  219. if out is None:
  220. out = os.path.dirname(os.path.abspath(args.audio))
  221. audio, sr = librosa.load(args.audio, sr=None, mono=False)
  222. slicer = Slicer(
  223. sr=sr,
  224. threshold=args.db_thresh,
  225. min_length=args.min_length,
  226. min_interval=args.min_interval,
  227. hop_size=args.hop_size,
  228. max_sil_kept=args.max_sil_kept,
  229. )
  230. chunks = slicer.slice(audio)
  231. if not os.path.exists(out):
  232. os.makedirs(out)
  233. for i, chunk in enumerate(chunks):
  234. if len(chunk.shape) > 1:
  235. chunk = chunk.T
  236. soundfile.write(
  237. os.path.join(
  238. out,
  239. f"%s_%d.wav"
  240. % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
  241. ),
  242. chunk,
  243. sr,
  244. )
# Run the command-line interface when executed as a script.
if __name__ == "__main__":
    main()