Coverage for skema/img2mml/models/encoding/positional_encoding_for_xfmer.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1""" 

The PositionalEncoding class is based on the PyTorch tutorials.

Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

4""" 

5 

6import torch 

7import torch.nn as nn 

8import math 

9 

10 

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The encoding table is precomputed once for ``max_len`` positions and
    stored via ``register_buffer`` with shape ``(max_len, 1, model_dimension)``.
    It is therefore not a trainable parameter. The size-1 middle axis lets it
    broadcast over the batch dimension in ``forward``.

    Args:
        model_dimension: Size of the embedding (last) dimension. Odd values
            are supported: even-indexed columns hold sine terms and
            odd-indexed columns hold cosine terms, so an odd size simply ends
            on a sine column.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table; sequences
            passed to ``forward`` must not be longer than this.
    """

    def __init__(self, model_dimension, dropout, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Column vector of positions 0 .. max_len - 1.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # Frequencies 1 / 10000^(2i / d), computed in log space.
        # There are ceil(d / 2) of them, one per even column.
        div_term = torch.exp(
            torch.arange(0, model_dimension, 2).float()
            * (-math.log(10000.0) / model_dimension)
        )  # (ceil(model_dim / 2),)

        pe = torch.zeros(max_len, model_dimension)  # (max_len, model_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d / 2) odd columns. Trim the frequency vector
        # so an odd model_dimension does not raise a shape-mismatch error.
        # For even d this slice is the whole vector (unchanged behaviour).
        pe[:, 1::2] = torch.cos(position * div_term[: model_dimension // 2])

        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, model_dim)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding for the first ``x.size(0)`` positions.

        Args:
            x: Sequence-first input of shape (seq_len, batch, model_dimension).
                ``seq_len`` must be <= ``max_len``; otherwise broadcasting
                against the encoding table fails.

        Returns:
            ``x`` plus the positional encoding, passed through dropout.
            The shape is the same as ``x``.
        """
        # (seq_len, 1, model_dim) broadcasts over the batch axis of x.
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)