嵌入

0

class SinusoidalPositionEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / dim)``.

    Args:
        dim: Feature size of the embeddings the encoding is added to. Odd
            values are supported; the last column then carries a sine term.
        max_len: Number of precomputed positions. ``forward`` fails for
            inputs whose dim 1 is longer than this.
    """

    def __init__(
        self,
        dim    : int,
        max_len: int = 512
    ):
        super().__init__()
        pe = torch.zeros(max_len, dim, dtype = torch.float)
        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
        # One frequency per sin/cos pair: ceil(dim / 2) values.
        div_term = torch.exp(torch.arange(0, dim, 2, dtype = torch.float) * (-math.log(10000.0) / dim))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # There are only dim // 2 odd columns. For odd dim, drop the final
        # frequency so the cos block fits instead of shape-mismatching.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        # Leading singleton axis -> shape (1, max_len, dim), broadcast over batch.
        # Non-persistent: rebuilt in __init__, so excluded from state_dict.
        self.register_buffer("pe", pe.unsqueeze(0), persistent = False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus the encodings for its first ``input.size(1)`` positions.

        Args:
            input: Expected shape ``(batch, seq_len, dim)`` with
                ``seq_len <= max_len``.

        Returns:
            ``input + pe[:, :seq_len]`` (broadcast over the batch axis).
        """
        return input + self.pe[:, :input.size(1)]

class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key projections.

    Each head vector ``x = (x1, x2)`` (split into halves on the last axis) is
    mapped to ``x * cos + rotate_half(x) * sin`` with
    ``rotate_half(x) = (-x2, x1)``. Features ``i`` and ``i + dim / 2`` share
    the angle ``pos * inv_freq[i]``.

    Args:
        dim: Per-head feature size (query/key last axis is viewed as
            ``(num_heads, dim)``). Must be even, or the cached tables will not
            match the head size.
        max_len: Number of precomputed positions. Longer inputs fail.
        num_heads: Number of heads the last axis of query/key is split into.
    """

    def __init__(
        self,
        dim      : int,
        max_len  : int = 512,
        num_heads: int = 8
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype = torch.float) / dim))
        t = torch.arange(0, max_len, dtype = torch.float)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim = -1)
        # Cached tables have shape (1, 1, max_len, dim): broadcast over batch and heads.
        self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0), persistent = False)
        self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0), persistent = False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map the last-axis halves ``(x1, x2)`` to ``(-x2, x1)``."""
        d = x.shape[-1]
        x1 = x[..., : d // 2]
        x2 = x[..., d // 2 :]
        return torch.cat((-x2, x1), dim = -1)

    def _apply_rotary(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate one ``(batch, seq_len, num_heads * dim)`` tensor by its own positions."""
        batch, seq_len = x.size(0), x.size(1)
        heads = x.view(batch, seq_len, self.num_heads, self.dim).transpose(1, 2)
        # Slice the tables to *this* tensor's length so query and key lengths
        # may differ (e.g. cross-attention or a key cache).
        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
        rotated = (heads * cos) + (self.rotate_half(heads) * sin)
        # reshape, not view: the transposed result is not guaranteed contiguous.
        return rotated.transpose(1, 2).reshape(batch, seq_len, -1)

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> List[torch.Tensor]:
        """Apply RoPE to ``query`` and ``key``.

        Args:
            query: ``(batch, q_len, num_heads * dim)``, ``q_len <= max_len``.
            key: ``(batch, k_len, num_heads * dim)``, ``k_len <= max_len``.

        Returns:
            ``[query_embedded, key_embedded]`` with the input shapes.
        """
        return [self._apply_rotary(query), self._apply_rotary(key)]

/**
 * Sinusoidal position embedding.
 *
 * Adds a fixed table to the input: even feature columns hold sin(pos * w_i),
 * odd columns hold cos(pos * w_i), with w_i = 10000^(-2i / dim).
 * Odd `dim` is supported (the last column then holds a sine term).
 */
class SinusoidalPositionEmbeddingImpl : public torch::nn::Module {

private:
    // Shape (1, max_len, dim); the leading singleton axis broadcasts over the batch.
    torch::Tensor pe{ nullptr };

public:
    /**
     * @param dim     Feature size of the embeddings the table is added to.
     * @param max_len Number of precomputed positions; forward() fails for longer inputs.
     */
    SinusoidalPositionEmbeddingImpl(int64_t dim, int64_t max_len = 512) {
        using torch::indexing::None;
        using torch::indexing::Slice;
        torch::Tensor table    = torch::zeros({ max_len, dim }, torch::kFloat);
        torch::Tensor position = torch::arange(0, max_len, torch::kFloat).unsqueeze(1);
        // One frequency per sin/cos pair: ceil(dim / 2) values.
        torch::Tensor div_term = torch::exp(torch::arange(0, dim, 2, torch::kFloat) * (-std::log(10000.0) / static_cast<double>(dim)));
        torch::Tensor angles   = position * div_term;
        table.index_put_({ Slice(), Slice(None, None, 2) }, torch::sin(angles));
        // There are only dim / 2 odd columns; drop the final frequency so the
        // cos block fits when dim is odd instead of shape-mismatching.
        table.index_put_({ Slice(), Slice(1, None, 2) }, torch::cos(angles.narrow(1, 0, dim / 2)));
        this->pe = this->register_buffer("pe", table.unsqueeze(0));
    }

public:
    /**
     * Returns input plus the encodings for its first input.size(1) positions.
     * Presumably input is (batch, seq_len, dim) with seq_len <= max_len — confirm with callers.
     */
    torch::Tensor forward(torch::Tensor input) {
        return input + this->pe.narrow(1, 0, input.size(1));
    }

};

TORCH_MODULE(SinusoidalPositionEmbedding);

/**
 * Rotary position embedding (RoPE) for query/key projections.
 *
 * Each head vector x = (x1, x2) (halves of the last axis) is mapped to
 * x * cos + rotate_half(x) * sin with rotate_half(x) = (-x2, x1). Features
 * i and i + dim / 2 share the angle pos * inv_freq[i].
 */
class RotaryPositionEmbeddingImpl : public torch::nn::Module {

private:
    int64_t dim       = 0;  // per-head feature size; must be even to match the cached tables
    int64_t num_heads = 8;
    // Both shaped (1, 1, max_len, dim) so they broadcast over batch and heads.
    torch::Tensor sin_cached{ nullptr };
    torch::Tensor cos_cached{ nullptr };

public:
    /**
     * @param dim       Per-head feature size (query/key last axis is viewed as (num_heads, dim)).
     * @param max_len   Number of precomputed positions; longer inputs fail.
     * @param num_heads Number of heads the last axis of query/key is split into.
     */
    RotaryPositionEmbeddingImpl(int64_t dim, int64_t max_len = 512, int64_t num_heads = 8) : dim(dim), num_heads(num_heads) {
        torch::Tensor inv_freq = 1.0 / torch::pow(10000.0, torch::arange(0, dim, 2, torch::kFloat) / static_cast<double>(dim));
        torch::Tensor t        = torch::arange(0, max_len, torch::kFloat);
        torch::Tensor freqs    = torch::outer(t, inv_freq);
        torch::Tensor emb      = torch::cat({ freqs, freqs }, -1);
        this->sin_cached = this->register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0));
        this->cos_cached = this->register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0));
    }

private:
    // Maps the last-axis halves (x1, x2) to (-x2, x1); works for any rank.
    torch::Tensor rotate_half(const torch::Tensor& x) {
        const int64_t d    = x.size(-1);
        const int64_t half = d / 2;
        auto x1 = x.narrow(-1, 0, half);
        auto x2 = x.narrow(-1, half, d - half);
        return torch::cat({ -x2, x1 }, -1);
    }

    // Rotates one (batch, seq_len, num_heads * dim) tensor by its own positions.
    torch::Tensor apply_rotary(const torch::Tensor& x) {
        const int64_t batch   = x.size(0);
        const int64_t seq_len = x.size(1);
        auto heads = x.view({ batch, seq_len, this->num_heads, this->dim }).transpose(1, 2);
        // Slice the tables to this tensor's own length so query and key
        // lengths may differ (e.g. cross-attention or a key cache).
        auto cos_t = this->cos_cached.narrow(2, 0, seq_len).expand_as(heads);
        auto sin_t = this->sin_cached.narrow(2, 0, seq_len).expand_as(heads);
        auto embedded = (heads * cos_t) + (this->rotate_half(heads) * sin_t);
        // reshape, not view: the transposed result is not guaranteed contiguous.
        return embedded.transpose(1, 2).reshape({ batch, seq_len, -1 });
    }

public:
    /**
     * Applies RoPE to query (batch, q_len, num_heads * dim) and key
     * (batch, k_len, num_heads * dim); returns both with their input shapes.
     */
    std::tuple<torch::Tensor, torch::Tensor> forward(const torch::Tensor& query, const torch::Tensor& key) {
        return std::make_tuple(this->apply_rotary(query), this->apply_rotary(key));
    }

};

TORCH_MODULE(RotaryPositionEmbedding);