libtorch stft分离音频
0
const int n_fft = 400;
const int hop_size = 80;
const int win_size = 400;
static auto wind = torch::hann_window(win_size).to(tensor.device());
// 分解幅度和相位
auto com = torch::stft(tensor.squeeze(-1), n_fft, hop_size, win_size, wind, true, "reflect", false, std::nullopt, true);
auto mag = torch::abs(com);
auto pha = torch::angle(com);
// 合成幅度和相位
auto low_mask = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
auto mid_mask = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
auto high_mask = torch::zeros({ mag.size(0), mag.size(1), 1 }).to(tensor.device());
low_mask .index({ torch::indexing::Slice(), torch::indexing::Slice( 0, 10), torch::indexing::Slice() }) = 1.0;
mid_mask .index({ torch::indexing::Slice(), torch::indexing::Slice( 10, 100), torch::indexing::Slice() }) = 1.0;
high_mask.index({ torch::indexing::Slice(), torch::indexing::Slice(100 ), torch::indexing::Slice() }) = 1.0;
auto low = torch::istft(torch::polar(mag * low_mask, pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
auto mid = torch::istft(torch::polar(mag * mid_mask, pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
auto high = torch::istft(torch::polar(mag * high_mask, pha), n_fft, hop_size, win_size, wind, true).unsqueeze(-1);
return {
low,
mid,
high
};
[[maybe_unused]] static void test_get_low_mid_high() {
std::ifstream stream_in ("D:/tmp/dzht.pcm", std::ios::binary);
std::ofstream stream_low ("D:/tmp/dzht_low.pcm", std::ios::binary);
std::ofstream stream_mid ("D:/tmp/dzht_mid.pcm", std::ios::binary);
std::ofstream stream_high("D:/tmp/dzht_high.pcm", std::ios::binary);
std::ofstream stream_out ("D:/tmp/dzht_.pcm", std::ios::binary);
int size = 800;
std::vector<short> pcm(size);
while(stream_in.read((char*) pcm.data(), sizeof(short) * size)) {
auto tensor = torch::from_blob(pcm.data(), { 1, size }, torch::kShort).to(torch::kFloat32).div(32768.0);
auto [low, mid, high] = chobits::model::get_low_mid_high(tensor);
low = low .mul(32768.0).to(torch::kShort);
mid = mid .mul(32768.0).to(torch::kShort);
high = high.mul(32768.0).to(torch::kShort);
auto out = low + mid + high;
stream_low .write((char*) low .data_ptr(), sizeof(short) * size);
stream_mid .write((char*) mid .data_ptr(), sizeof(short) * size);
stream_high.write((char*) high.data_ptr(), sizeof(short) * size);
stream_out .write((char*) out .data_ptr(), sizeof(short) * size);
}
}