Embeddings
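The first class builds the fixed sinusoidal table from "Attention Is All You Need": even channels get a sine, odd channels a cosine, with geometrically spaced frequencies

PE(pos, 2i) = sin(pos / 10000^(2i/d)),    PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

where d is the embedding width. The div_term computed below is exactly 10000^(-2i/d).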
import math
from typing import List

import torch
import torch.nn as nn

class SinusoidalPositionEmbedding(nn.Module):

    def __init__(
        self,
        dim    : int,
        max_len: int = 512
    ):
        super().__init__()
        pe = torch.zeros(max_len, dim, dtype=torch.float)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i/dim) for the even channel indices 2i
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # shape (1, max_len, dim); excluded from the state dict
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # input: (batch, seq_len, dim); the cached table broadcasts over the batch
        return input + self.pe[:, :input.size(1)]
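A minimal smoke test, with shapes chosen purely for illustration (width 64, batch 2, 10 tokens):

embed = SinusoidalPositionEmbedding(dim=64, max_len=512)
x = torch.randn(2, 10, 64)           # (batch, seq_len, dim)
out = embed(x)
assert out.shape == (2, 10, 64)      # same shape, position information added in

The second class implements rotary position embedding (RoPE, Su et al.), which rotates each query/key feature pair by a position-dependent angle instead of adding anything to the input. Note that here dim is the per-head dimension, and the inputs are the flattened (batch, seq_len, num_heads * dim) projections.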
class RotaryPositionEmbedding(nn.Module):

    def __init__(
        self,
        dim      : int,
        max_len  : int = 512,
        num_heads: int = 8
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        t = torch.arange(0, max_len, dtype=torch.float)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # duplicate the frequencies so the table covers both halves paired by rotate_half
        emb = torch.cat((freqs, freqs), dim=-1)
        # shape (1, 1, max_len, dim) so the tables broadcast over batch and heads
        self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0), persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        # (x1, x2) -> (-x2, x1): pairs each channel with its partner half a width away
        d = x.shape[-1]
        x1 = x[..., : d // 2]
        x2 = x[..., d // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> List[torch.Tensor]:
        # reshape to (batch, num_heads, seq_len, dim)
        q = query.view(query.size(0), query.size(1), self.num_heads, self.dim).transpose(1, 2)
        k = key.view(key.size(0), key.size(1), self.num_heads, self.dim).transpose(1, 2)
        # slice each table to the matching sequence length (query and key may differ)
        cos_q = self.cos_cached[:, :, :q.size(2), :]
        sin_q = self.sin_cached[:, :, :q.size(2), :]
        cos_k = self.cos_cached[:, :, :k.size(2), :]
        sin_k = self.sin_cached[:, :, :k.size(2), :]
        q_embed = (q * cos_q) + (self.rotate_half(q) * sin_q)
        k_embed = (k * cos_k) + (self.rotate_half(k) * sin_k)
        # reshape (not view): the transposed tensors are not guaranteed contiguous
        return [
            q_embed.transpose(1, 2).reshape(query.size(0), query.size(1), -1),
            k_embed.transpose(1, 2).reshape(key.size(0), key.size(1), -1)
        ]
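Again a sketch with made-up shapes: 8 heads of dimension 32, so the flattened width is num_heads * dim = 256. Only the queries and keys are rotated; values are untouched.

rope = RotaryPositionEmbedding(dim=32, max_len=512, num_heads=8)
q = torch.randn(2, 10, 256)          # (batch, seq_len, num_heads * dim)
k = torch.randn(2, 10, 256)
q_rot, k_rot = rope(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

The C++ (LibTorch) equivalents follow.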
/**
 * Sinusoidal position embedding (LibTorch equivalent of the Python class above)
 */
class SinusoidalPositionEmbeddingImpl : public torch::nn::Module {

private:
    torch::Tensor pe{ nullptr };

public:
    SinusoidalPositionEmbeddingImpl(int64_t dim, int64_t max_len = 512) {
        torch::Tensor pe = torch::zeros({ max_len, dim }, torch::kFloat);
        torch::Tensor position = torch::arange(0, max_len, torch::kFloat).unsqueeze(1);
        torch::Tensor div_term = torch::exp(torch::arange(0, dim, 2, torch::kFloat) * (-std::log(10000.0) / static_cast<double>(dim)));
        // pe.slice(1, 0, dim, 2) = torch::sin(position * div_term.unsqueeze(0));
        // pe.slice(1, 1, dim, 2) = torch::cos(position * div_term.unsqueeze(0));
        pe.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, torch::indexing::None, 2) }, torch::sin(position * div_term));
        pe.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(1, torch::indexing::None, 2) }, torch::cos(position * div_term));
        this->pe = this->register_buffer("pe", pe.unsqueeze(0));
    }

public:
    torch::Tensor forward(torch::Tensor input) {
        // input + this->pe.slice(1, 0, input.size(1));
        return input + this->pe.narrow(1, 0, input.size(1));
    }

};
TORCH_MODULE(SinusoidalPositionEmbedding);
/**
 * Rotary position embedding (RoPE, LibTorch equivalent of the Python class above)
 */
class RotaryPositionEmbeddingImpl : public torch::nn::Module {

private:
    int64_t dim       = 0;
    int64_t num_heads = 8;
    torch::Tensor sin_cached{ nullptr };
    torch::Tensor cos_cached{ nullptr };

public:
    RotaryPositionEmbeddingImpl(int64_t dim, int64_t max_len = 512, int64_t num_heads = 8) : dim(dim), num_heads(num_heads) {
        torch::Tensor inv_freq = 1.0 / torch::pow(10000.0, torch::arange(0, dim, 2, torch::kFloat) / static_cast<double>(dim));
        torch::Tensor t = torch::arange(0, max_len, torch::kFloat);
        torch::Tensor freqs = torch::outer(t, inv_freq);
        torch::Tensor emb = torch::cat({ freqs, freqs }, -1);
        // shape (1, 1, max_len, dim) so the tables broadcast over batch and heads
        this->sin_cached = this->register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0));
        this->cos_cached = this->register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0));
    }

private:
    torch::Tensor rotate_half(const torch::Tensor& x) {
        // (x1, x2) -> (-x2, x1): pairs each channel with its partner half a width away
        int64_t d = x.size(-1);
        // auto x1 = x.slice(-1, 0, d / 2);
        // auto x2 = x.slice(-1, d / 2, d);
        auto x1 = x.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, d / 2) });
        auto x2 = x.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(d / 2, torch::indexing::None) });
        return torch::cat({ -x2, x1 }, -1);
    }

public:
    std::tuple<torch::Tensor, torch::Tensor> forward(const torch::Tensor& query, const torch::Tensor& key) {
        // reshape to (batch, num_heads, seq_len, dim)
        auto q = query.view({ query.size(0), query.size(1), this->num_heads, this->dim }).transpose(1, 2);
        auto k = key.view({ key.size(0), key.size(1), this->num_heads, this->dim }).transpose(1, 2);
        // slice each table to the matching sequence length (query and key may differ)
        // auto cos_q = this->cos_cached.slice(2, 0, q.size(2));
        auto cos_q = this->cos_cached.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, q.size(2)), torch::indexing::Slice() });
        auto sin_q = this->sin_cached.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, q.size(2)), torch::indexing::Slice() });
        auto cos_k = this->cos_cached.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, k.size(2)), torch::indexing::Slice() });
        auto sin_k = this->sin_cached.index({ torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, k.size(2)), torch::indexing::Slice() });
        auto q_embed = (q * cos_q) + (this->rotate_half(q) * sin_q);
        auto k_embed = (k * cos_k) + (this->rotate_half(k) * sin_k);
        // reshape (not view): the transposed tensors are not guaranteed contiguous
        return std::make_tuple(
            q_embed.transpose(1, 2).reshape({ query.size(0), query.size(1), -1 }),
            k_embed.transpose(1, 2).reshape({ key.size(0), key.size(1), -1 })
        );
    }

};
TORCH_MODULE(RotaryPositionEmbedding);