LibTorch短时傅里叶变换
0
/**
* Short-time Fourier transform (STFT) of PCM audio.
*
* Frequency bins: 201 = win_size / 2 + 1 (for the default win_size = n_fft = 400).
* Sample count = frame count, as listed by the original author:
* 480 = 7 | 4800 = 61 | 48000 = 601
* Shape noted by the original author: [1, 201, 61, 2[real part, imaginary part]]
* NOTE(review): the pcm_stft definition in this file stacks magnitude and
* phase (not real/imaginary) in the last dimension and squeezes the leading
* batch dimension away, so the shape note above looks out of date - confirm.
*
* @param pcm PCM samples
* @param n_fft size of the Fourier transform
* @param hop_size distance between neighbouring sliding-window frames
* @param win_size size of the window frame and of the STFT filter
*
* @return tensor
*/
extern torch::Tensor pcm_stft(
std::vector<short>& pcm,
int n_fft = 400,
int hop_size = 80,
int win_size = 400
);
/**
* Inverse short-time Fourier transform (ISTFT) back to PCM audio.
*
* NOTE(review): the pcm_istft definition in this file reads index 0 of the
* last dimension as magnitude and index 1 as phase, and adds a leading batch
* dimension itself - i.e. it presumably expects the tensor produced by
* pcm_stft; confirm against callers.
*
* @param tensor tensor
* @param n_fft size of the Fourier transform
* @param hop_size distance between neighbouring sliding-window frames
* @param win_size size of the window frame and of the STFT filter
*
* @return PCM samples
*/
extern std::vector<short> pcm_istft(
const torch::Tensor& tensor,
int n_fft = 400,
int hop_size = 80,
int win_size = 400
);
/**
 * Converts 16-bit PCM samples into a magnitude/phase spectrogram.
 *
 * The samples are converted to float32, divided by NORMALIZATION and passed
 * through torch::stft with a Hann window (center = true, "reflect" padding,
 * normalized = false, complex output). The complex result is split into
 * magnitude and phase, which are stacked in a new trailing dimension.
 *
 * Only the leading batch dimension (always 1 here) is removed from the
 * result, so the returned tensor is always [freq, frames, 2] even when the
 * input produces a single frame.
 *
 * @param pcm      PCM samples; the buffer is only read, never modified
 * @param n_fft    size of the Fourier transform
 * @param hop_size distance between neighbouring frames
 * @param win_size length of the Hann window
 *
 * @return tensor whose last dimension holds { magnitude, phase }
 */
torch::Tensor lifuren::dataset::audio::pcm_stft(
    std::vector<short>& pcm,
    int n_fft,
    int hop_size,
    int win_size
) {
    // from_blob only borrows pcm's storage; the dtype conversion creates an
    // owning float copy, so nothing below aliases the caller's vector.
    // The length is cast to int64_t so very large buffers are not truncated.
    const auto samples = torch::from_blob(pcm.data(), { 1, static_cast<int64_t>(pcm.size()) }, torch::kShort).to(torch::kFloat32) / NORMALIZATION;
    const auto window = torch::hann_window(win_size);
    const auto spectrum = torch::stft(samples, n_fft, hop_size, win_size, window, true, "reflect", false, std::nullopt, true);
    // View the complex spectrum as [..., 2] = { real, imaginary }.
    const auto parts = torch::view_as_real(spectrum);
    // Magnitude: sqrt(x^2 + y^2)
    const auto mag = torch::sqrt(parts.pow(2).sum(-1));
    // Phase: atan2(y, x)
    const auto pha = torch::atan2(parts.index({ "...", 1 }), parts.index({ "...", 0 }));
    // squeeze(0) drops only the batch dimension; a bare squeeze() would also
    // drop a size-1 frame dimension and break the round trip through pcm_istft.
    return torch::stack({ mag, pha }, -1).squeeze(0);
}
/**
 * Reconstructs 16-bit PCM samples from a magnitude/phase spectrogram.
 *
 * A leading batch dimension is added to the input, index 0 / index 1 of the
 * last dimension are read as magnitude / phase, the complex spectrum is
 * rebuilt as mag * (cos(pha) + i * sin(pha)) and passed through torch::istft
 * with a Hann window (center = true). The waveform is then scaled by
 * NORMALIZATION and converted to short.
 *
 * Samples are clamped to the int16 range before the conversion, because
 * converting an out-of-range float to short is undefined behaviour.
 *
 * @param tensor   spectrogram whose last dimension holds { magnitude, phase }
 * @param n_fft    size of the Fourier transform
 * @param hop_size distance between neighbouring frames
 * @param win_size length of the Hann window
 *
 * @return PCM samples
 */
std::vector<short> lifuren::dataset::audio::pcm_istft(
    const torch::Tensor& tensor,
    int n_fft,
    int hop_size,
    int win_size
) {
    const auto batched = tensor.unsqueeze(0);
    // Create the window on the same device as the input so istft does not
    // mix devices when the spectrogram comes from an accelerator.
    const auto window = torch::hann_window(win_size, torch::TensorOptions().device(tensor.device()));
    const auto mag = batched.index({ "...", 0 });
    const auto pha = batched.index({ "...", 1 });
    // Polar -> rectangular form of the complex spectrum.
    const auto spectrum = torch::complex(mag * torch::cos(pha), mag * torch::sin(pha));
    const auto wave = torch::istft(spectrum, n_fft, hop_size, win_size, window, true) * NORMALIZATION;
    // Clamp to the int16 range, then make sure the buffer we read through a
    // raw pointer really is contiguous float32 memory on the CPU.
    const auto samples = torch::clamp(wave, -32768.0, 32767.0).to(torch::kCPU, torch::kFloat32).contiguous();
    const float* data = samples.data_ptr<float>();
    std::vector<short> pcm(static_cast<std::size_t>(samples.numel()));
    // The float -> short conversion truncates toward zero, as before.
    std::copy_n(data, pcm.size(), pcm.data());
    return pcm;
}