libtorch: splitting audio into low/mid/high frequency bands with STFT (libtorch stft分离音频)

0

    const int n_fft    = 400;
    const int hop_size = 80;
    const int win_size = 400;
    static auto wind = torch::hann_window(win_size).to(tensor.device());
    // 分解幅度和相位
    auto com = torch::stft(tensor.squeeze(-1), n_fft, hop_size, win_size, wind, true, "reflect", false, std::nullopt, true);
    auto mag = torch::abs(com);
    auto pha = torch::angle(com);
    // 合成幅度和相位
    auto low_mask  = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
    auto mid_mask  = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
    auto high_mask = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
    low_mask .index({ torch::indexing::Slice(), torch::indexing::Slice(  0,  10), torch::indexing::Slice() }) = 1.0;
    mid_mask .index({ torch::indexing::Slice(), torch::indexing::Slice( 10, 100), torch::indexing::Slice() }) = 1.0;
    high_mask.index({ torch::indexing::Slice(), torch::indexing::Slice(100     ), torch::indexing::Slice() }) = 1.0;
    auto low  = torch::istft(torch::polar(mag * low_mask,  pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
    auto mid  = torch::istft(torch::polar(mag * mid_mask,  pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
    auto high = torch::istft(torch::polar(mag * high_mask, pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
    return {
        low,
        mid,
        high
    };

[[maybe_unused]] static void test_get_low_mid_high() {
    std::ifstream stream_in  ("D:/tmp/dzht.pcm",      std::ios::binary);
    std::ofstream stream_low ("D:/tmp/dzht_low.pcm",  std::ios::binary);
    std::ofstream stream_mid ("D:/tmp/dzht_mid.pcm",  std::ios::binary);
    std::ofstream stream_high("D:/tmp/dzht_high.pcm", std::ios::binary);
    std::ofstream stream_out ("D:/tmp/dzht_.pcm",     std::ios::binary);
    int size = 800;
    std::vector<short> pcm(size);
    while(stream_in.read((char*) pcm.data(), sizeof(short) * size)) {
        auto tensor = torch::from_blob(pcm.data(), { 1, size }, torch::kShort).to(torch::kFloat32).div(32768.0);
        auto [low, mid, high] = chobits::model::get_low_mid_high(tensor);
        low  = low .mul(32768.0).to(torch::kShort);
        mid  = mid .mul(32768.0).to(torch::kShort);
        high = high.mul(32768.0).to(torch::kShort);
        auto out = low + mid + high;
        stream_low .write((char*) low .data_ptr(), sizeof(short) * size);
        stream_mid .write((char*) mid .data_ptr(), sizeof(short) * size);
        stream_high.write((char*) high.data_ptr(), sizeof(short) * size);
        stream_out .write((char*) out .data_ptr(), sizeof(short) * size);
    }
}