LibTorch短时傅里叶变换

0

/**
 * Short-time Fourier transform of 16-bit PCM samples (Hann window, centered
 * frames with reflect padding, one-sided spectrum).
 *
 * Output shape: [n_fft / 2 + 1, frames, 2]
 *   - n_fft / 2 + 1 frequency bins (n_fft = 400 -> 201)
 *   - frames = 1 + pcm.size() / hop_size; with hop_size = 80:
 *     480 samples -> 7, 4800 -> 61, 48000 -> 601
 *   - last axis: [magnitude, phase in radians] -- NOT [real, imaginary]
 *   - the leading batch dimension used internally is removed
 *
 * @param pcm      PCM samples; only read (non-const because torch::from_blob takes a mutable pointer)
 * @param n_fft    FFT size
 * @param hop_size distance in samples between adjacent frames
 * @param win_size length of the window / STFT filter
 *
 * @return magnitude/phase spectrum tensor laid out as described above
 */
extern torch::Tensor pcm_stft(
    std::vector<short>& pcm,
    int n_fft    = 400,
    int hop_size = 80,
    int win_size = 400
);

/**
 * Inverse short-time Fourier transform: turns a spectrum tensor back into
 * 16-bit PCM samples.
 *
 * The tensor is read with the layout pcm_stft produces: [freq, frames, 2],
 * where [..., 0] is the magnitude and [..., 1] the phase in radians.
 *
 * @param tensor   magnitude/phase spectrum to invert
 * @param n_fft    FFT size
 * @param hop_size distance in samples between adjacent frames
 * @param win_size length of the window / STFT filter
 *
 * @return reconstructed PCM samples
 */
extern std::vector<short> pcm_istft(const torch::Tensor& tensor, int n_fft = 400, int hop_size = 80, int win_size = 400);

/**
 * Short-time Fourier transform of 16-bit PCM samples.
 *
 * The samples are converted to float32 and divided by NORMALIZATION, then
 * transformed with a Hann window (center = true, reflect padding, one-sided
 * spectrum, complex output).
 *
 * @param pcm      PCM samples; only read (non-const because torch::from_blob
 *                 takes a mutable pointer). Must contain more than n_fft / 2
 *                 samples, otherwise the reflect padding inside torch::stft throws.
 * @param n_fft    FFT size; the result has n_fft / 2 + 1 frequency bins
 * @param hop_size distance in samples between adjacent frames
 * @param win_size Hann window length
 *
 * @return tensor of shape [n_fft / 2 + 1, frames, 2] with
 *         frames = 1 + pcm.size() / hop_size (e.g. 4800 samples, hop 80 -> 61);
 *         [..., 0] is the magnitude, [..., 1] the phase in radians.
 */
torch::Tensor lifuren::dataset::audio::pcm_stft(
    std::vector<short>& pcm,
    int n_fft,
    int hop_size,
    int win_size
) {
    // from_blob does not copy; .to(kFloat32) produces an owning tensor, so the
    // result never aliases pcm's storage. IntArrayRef holds int64_t, so cast the
    // length to that (an int cast could overflow on very long inputs).
    auto samples = torch::from_blob(
        pcm.data(),
        { 1, static_cast<int64_t>(pcm.size()) },
        torch::kShort
    ).to(torch::kFloat32) / NORMALIZATION;
    auto window   = torch::hann_window(win_size);
    auto spectrum = torch::stft(samples, n_fft, hop_size, win_size, window, true, "reflect", false, std::nullopt, true);
    // Magnitude: |z| = sqrt(re^2 + im^2)
    auto mag = torch::abs(spectrum);
    // Phase: atan2(im, re)
    auto pha = torch::angle(spectrum);
    // Drop only the batch dimension; a bare squeeze() would also remove the
    // frame axis when the input produces exactly one frame.
    return torch::stack({ mag, pha }, -1).squeeze(0);
}

/**
 * Inverse short-time Fourier transform back to 16-bit PCM samples.
 *
 * @param tensor   spectrum in the layout pcm_stft returns: [freq, frames, 2],
 *                 [..., 0] magnitude, [..., 1] phase in radians. May live on
 *                 any device; the result is copied back to the CPU.
 * @param n_fft    FFT size
 * @param hop_size distance in samples between adjacent frames
 * @param win_size Hann window length
 *
 * @return PCM samples, scaled by NORMALIZATION, rounded to the nearest integer
 *         and saturated to the int16 range. No explicit length is passed to
 *         torch::istft, so the output has hop_size * (frames - 1) samples and
 *         may be shorter than the signal that was originally transformed.
 */
std::vector<short> lifuren::dataset::audio::pcm_istft(
    const torch::Tensor& tensor,
    int n_fft,
    int hop_size,
    int win_size
) {
    // Add a batch dimension: [1, freq, frames, 2].
    auto batched = tensor.unsqueeze(0);
    auto mag     = batched.index({ "...", 0 });
    auto pha     = batched.index({ "...", 1 });
    // Build the window with the spectrum's dtype/device so istft accepts
    // GPU or double-precision input (for CPU float32 this is the default window).
    auto window   = torch::hann_window(win_size, mag.options());
    // polar(mag, pha) == mag * cos(pha) + i * mag * sin(pha)
    auto spectrum = torch::polar(mag, pha);
    auto wave     = torch::istft(spectrum, n_fft, hop_size, win_size, window, true) * NORMALIZATION;
    // Round, then saturate to the int16 range before narrowing: converting an
    // out-of-range float to short is undefined behaviour. The typed data_ptr
    // below needs a contiguous CPU int16 tensor.
    auto samples = wave.round()
                       .clamp(-32768, 32767)
                       .to(torch::kCPU)
                       .to(torch::kShort)
                       .contiguous();
    const short* begin = samples.data_ptr<short>();
    return std::vector<short>(begin, begin + samples.numel());
}