utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import json
  2. import numpy as np
  3. import torch
  4. from tqdm import tqdm
  5. def load_data(file_name: str = "./lib/name_params.json") -> dict:
  6. with open(file_name, "r") as f:
  7. data = json.load(f)
  8. return data
  9. def make_padding(width, cropsize, offset):
  10. left = offset
  11. roi_size = cropsize - left * 2
  12. if roi_size == 0:
  13. roi_size = cropsize
  14. right = roi_size - (width % roi_size) + left
  15. return left, right, roi_size
  16. def inference(X_spec, device, model, aggressiveness, data):
  17. """
  18. data : dic configs
  19. """
  20. def _execute(
  21. X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half=True
  22. ):
  23. model.eval()
  24. with torch.no_grad():
  25. preds = []
  26. iterations = [n_window]
  27. total_iterations = sum(iterations)
  28. for i in tqdm(range(n_window)):
  29. start = i * roi_size
  30. X_mag_window = X_mag_pad[
  31. None, :, :, start : start + data["window_size"]
  32. ]
  33. X_mag_window = torch.from_numpy(X_mag_window)
  34. if is_half:
  35. X_mag_window = X_mag_window.half()
  36. X_mag_window = X_mag_window.to(device)
  37. pred = model.predict(X_mag_window, aggressiveness)
  38. pred = pred.detach().cpu().numpy()
  39. preds.append(pred[0])
  40. pred = np.concatenate(preds, axis=2)
  41. return pred
  42. def preprocess(X_spec):
  43. X_mag = np.abs(X_spec)
  44. X_phase = np.angle(X_spec)
  45. return X_mag, X_phase
  46. X_mag, X_phase = preprocess(X_spec)
  47. coef = X_mag.max()
  48. X_mag_pre = X_mag / coef
  49. n_frame = X_mag_pre.shape[2]
  50. pad_l, pad_r, roi_size = make_padding(n_frame, data["window_size"], model.offset)
  51. n_window = int(np.ceil(n_frame / roi_size))
  52. X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
  53. if list(model.state_dict().values())[0].dtype == torch.float16:
  54. is_half = True
  55. else:
  56. is_half = False
  57. pred = _execute(
  58. X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half
  59. )
  60. pred = pred[:, :, :n_frame]
  61. if data["tta"]:
  62. pad_l += roi_size // 2
  63. pad_r += roi_size // 2
  64. n_window += 1
  65. X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
  66. pred_tta = _execute(
  67. X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half
  68. )
  69. pred_tta = pred_tta[:, :, roi_size // 2 :]
  70. pred_tta = pred_tta[:, :, :n_frame]
  71. return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.0j * X_phase)
  72. else:
  73. return pred * coef, X_mag, np.exp(1.0j * X_phase)
  74. def _get_name_params(model_path, model_hash):
  75. data = load_data()
  76. flag = False
  77. ModelName = model_path
  78. for type in list(data):
  79. for model in list(data[type][0]):
  80. for i in range(len(data[type][0][model])):
  81. if str(data[type][0][model][i]["hash_name"]) == model_hash:
  82. flag = True
  83. elif str(data[type][0][model][i]["hash_name"]) in ModelName:
  84. flag = True
  85. if flag:
  86. model_params_auto = data[type][0][model][i]["model_params"]
  87. param_name_auto = data[type][0][model][i]["param_name"]
  88. if type == "equivalent":
  89. return param_name_auto, model_params_auto
  90. else:
  91. flag = False
  92. return param_name_auto, model_params_auto