logmmse.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # The MIT License (MIT)
  2. #
  3. # Copyright (c) 2015 braindead
  4. #
  5. # Permission is hereby granted, free of charge, to any person obtaining a copy
  6. # of this software and associated documentation files (the "Software"), to deal
  7. # in the Software without restriction, including without limitation the rights
  8. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. # copies of the Software, and to permit persons to whom the Software is
  10. # furnished to do so, subject to the following conditions:
  11. #
  12. # The above copyright notice and this permission notice shall be included in all
  13. # copies or substantial portions of the Software.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. # SOFTWARE.
  22. #
  23. #
  24. # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
  25. # simply modified the interface to meet my needs.
  26. import numpy as np
  27. import math
  28. from scipy.special import expn
  29. from collections import namedtuple
  30. NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
  31. def profile_noise(noise, sampling_rate, window_size=0):
  32. """
  33. Creates a profile of the noise in a given waveform.
  34. :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
  35. :param sampling_rate: the sampling rate of the audio
  36. :param window_size: the size of the window the logmmse algorithm operates on. A default value
  37. will be picked if left as 0.
  38. :return: a NoiseProfile object
  39. """
  40. noise, dtype = to_float(noise)
  41. noise += np.finfo(np.float64).eps
  42. if window_size == 0:
  43. window_size = int(math.floor(0.02 * sampling_rate))
  44. if window_size % 2 == 1:
  45. window_size = window_size + 1
  46. perc = 50
  47. len1 = int(math.floor(window_size * perc / 100))
  48. len2 = int(window_size - len1)
  49. win = np.hanning(window_size)
  50. win = win * len2 / np.sum(win)
  51. n_fft = 2 * window_size
  52. noise_mean = np.zeros(n_fft)
  53. n_frames = len(noise) // window_size
  54. for j in range(0, window_size * n_frames, window_size):
  55. noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
  56. noise_mu2 = (noise_mean / n_frames) ** 2
  57. return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
  58. def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
  59. """
  60. Cleans the noise from a speech waveform given a noise profile. The waveform must have the
  61. same sampling rate as the one used to create the noise profile.
  62. :param wav: a speech waveform as a numpy array of floats or ints.
  63. :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
  64. the same) waveform.
  65. :param eta: voice threshold for noise update. While the voice activation detection value is
  66. below this threshold, the noise profile will be continuously updated throughout the audio.
  67. Set to 0 to disable updating the noise profile.
  68. :return: the clean wav as a numpy array of floats or ints of the same length.
  69. """
  70. wav, dtype = to_float(wav)
  71. wav += np.finfo(np.float64).eps
  72. p = noise_profile
  73. nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
  74. x_final = np.zeros(nframes * p.len2)
  75. aa = 0.98
  76. mu = 0.98
  77. ksi_min = 10 ** (-25 / 10)
  78. x_old = np.zeros(p.len1)
  79. xk_prev = np.zeros(p.len1)
  80. noise_mu2 = p.noise_mu2
  81. for k in range(0, nframes * p.len2, p.len2):
  82. insign = p.win * wav[k:k + p.window_size]
  83. spec = np.fft.fft(insign, p.n_fft, axis=0)
  84. sig = np.absolute(spec)
  85. sig2 = sig ** 2
  86. gammak = np.minimum(sig2 / noise_mu2, 40)
  87. if xk_prev.all() == 0:
  88. ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
  89. else:
  90. ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
  91. ksi = np.maximum(ksi_min, ksi)
  92. log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
  93. vad_decision = np.sum(log_sigma_k) / p.window_size
  94. if vad_decision < eta:
  95. noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
  96. a = ksi / (1 + ksi)
  97. vk = a * gammak
  98. ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
  99. hw = a * np.exp(ei_vk)
  100. sig = sig * hw
  101. xk_prev = sig ** 2
  102. xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
  103. xi_w = np.real(xi_w)
  104. x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
  105. x_old = xi_w[p.len1:p.window_size]
  106. output = from_float(x_final, dtype)
  107. output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
  108. return output
## Alternative VAD algorithm to webrtcvad. It has the advantage of not requiring that package to
## be installed, and it also works for any sampling rate. Maybe I'll eventually use it instead of
## webrtcvad.
  112. # def vad(wav, sampling_rate, eta=0.15, window_size=0):
  113. # """
  114. # TODO: fix doc
  115. # Creates a profile of the noise in a given waveform.
  116. #
  117. # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
  118. # :param sampling_rate: the sampling rate of the audio
  119. # :param window_size: the size of the window the logmmse algorithm operates on. A default value
  120. # will be picked if left as 0.
  121. # :param eta: voice threshold for noise update. While the voice activation detection value is
  122. # below this threshold, the noise profile will be continuously updated throughout the audio.
  123. # Set to 0 to disable updating the noise profile.
  124. # """
  125. # wav, dtype = to_float(wav)
  126. # wav += np.finfo(np.float64).eps
  127. #
  128. # if window_size == 0:
  129. # window_size = int(math.floor(0.02 * sampling_rate))
  130. #
  131. # if window_size % 2 == 1:
  132. # window_size = window_size + 1
  133. #
  134. # perc = 50
  135. # len1 = int(math.floor(window_size * perc / 100))
  136. # len2 = int(window_size - len1)
  137. #
  138. # win = np.hanning(window_size)
  139. # win = win * len2 / np.sum(win)
  140. # n_fft = 2 * window_size
  141. #
  142. # wav_mean = np.zeros(n_fft)
  143. # n_frames = len(wav) // window_size
  144. # for j in range(0, window_size * n_frames, window_size):
  145. # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
  146. # noise_mu2 = (wav_mean / n_frames) ** 2
  147. #
  148. # wav, dtype = to_float(wav)
  149. # wav += np.finfo(np.float64).eps
  150. #
  151. # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
  152. # vad = np.zeros(nframes * len2, dtype=np.bool)
  153. #
  154. # aa = 0.98
  155. # mu = 0.98
  156. # ksi_min = 10 ** (-25 / 10)
  157. #
  158. # xk_prev = np.zeros(len1)
  159. # noise_mu2 = noise_mu2
  160. # for k in range(0, nframes * len2, len2):
  161. # insign = win * wav[k:k + window_size]
  162. #
  163. # spec = np.fft.fft(insign, n_fft, axis=0)
  164. # sig = np.absolute(spec)
  165. # sig2 = sig ** 2
  166. #
  167. # gammak = np.minimum(sig2 / noise_mu2, 40)
  168. #
  169. # if xk_prev.all() == 0:
  170. # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
  171. # else:
  172. # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
  173. # ksi = np.maximum(ksi_min, ksi)
  174. #
  175. # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
  176. # vad_decision = np.sum(log_sigma_k) / window_size
  177. # if vad_decision < eta:
  178. # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
  179. # print(vad_decision)
  180. #
  181. # a = ksi / (1 + ksi)
  182. # vk = a * gammak
  183. # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
  184. # hw = a * np.exp(ei_vk)
  185. # sig = sig * hw
  186. # xk_prev = sig ** 2
  187. #
  188. # vad[k:k + len2] = vad_decision >= eta
  189. #
  190. # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
  191. # return vad
  192. def to_float(_input):
  193. if _input.dtype == np.float64:
  194. return _input, _input.dtype
  195. elif _input.dtype == np.float32:
  196. return _input.astype(np.float64), _input.dtype
  197. elif _input.dtype == np.uint8:
  198. return (_input - 128) / 128., _input.dtype
  199. elif _input.dtype == np.int16:
  200. return _input / 32768., _input.dtype
  201. elif _input.dtype == np.int32:
  202. return _input / 2147483648., _input.dtype
  203. raise ValueError('Unsupported wave file format')
  204. def from_float(_input, dtype):
  205. if dtype == np.float64:
  206. return _input, np.float64
  207. elif dtype == np.float32:
  208. return _input.astype(np.float32)
  209. elif dtype == np.uint8:
  210. return ((_input * 128) + 128).astype(np.uint8)
  211. elif dtype == np.int16:
  212. return (_input * 32768).astype(np.int16)
  213. elif dtype == np.int32:
  214. print(_input)
  215. return (_input * 2147483648).astype(np.int32)
  216. raise ValueError('Unsupported wave file format')