Coverage for skema/img2mml/models/encoding/positional_encoding_for_xfmer.py: 100% (17 statements)
coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1"""
2PositionalEncoding class has been taken from PyTorch tutorials.
3<Source>: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
4"""
6import torch
7import torch.nn as nn
8import math
11class PositionalEncoding(nn.Module):
12 def __init__(self, model_dimension, dropout, max_len):
13 super(PositionalEncoding, self).__init__()
14 self.dropout = nn.Dropout(p=dropout)
15 pe = torch.zeros(max_len, model_dimension) # (max_len, model_dim)
16 position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
17 1
18 ) # (max_len, 1)
19 div_term = torch.exp(
20 torch.arange(0, model_dimension, 2).float()
21 * (-math.log(10000.0) / model_dimension)
22 ) # ([model_dim//2])
23 pe[:, 0::2] = torch.sin(position * div_term) # (max_len, model_dim//2)
24 pe[:, 1::2] = torch.cos(position * div_term) # (max_len, model_dim//2)
25 pe = pe.unsqueeze(0).transpose(0, 1) # (max_len, 1, model_dim)
26 self.register_buffer("pe", pe)
28 def forward(self, x: torch.tensor) -> torch.tensor:
29 # x: (max_len, B, embed_dim)
30 # print("x shape:", x.shape)
31 # print("x_ shape:", self.pe[:x.size(0), :].shape)
32 x = x + self.pe[: x.size(0), :]
33 return self.dropout(x)
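
A minimal usage sketch, for orientation only: the model dimension, dropout rate, maximum length, and the names pos_enc and tokens below are assumed values chosen for illustration, not part of the covered module. Each table entry follows the standard sinusoid PE(pos, 2i) = sin(pos / 10000^(2i/model_dim)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/model_dim)), which is exactly what the exp/log form of div_term computes.

import torch

# Hypothetical hyperparameters, assumed for this sketch.
pos_enc = PositionalEncoding(model_dimension=256, dropout=0.1, max_len=1000)

# Embedded tokens in the (seq_len, batch, model_dim) layout the module
# expects: here 20 positions, batch size 4.
tokens = torch.randn(20, 4, 256)
out = pos_enc(tokens)  # same shape: (20, 4, 256)

Because the encodings are additive and depend only on position, the same module instance can be applied to any sequence length up to max_len.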