Coverage for skema/img2mml/models/encoding/positional_features_for_cnn_encoder.py: 22%

18 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import math 

2import torch 

3 

4 

def get_range_vector(size: int, device) -> torch.Tensor:
    """
    Build the index vector ``[0, 1, ..., size - 1]``.

    Parameters
    ----------
    size : ``int``
        Number of entries in the returned vector.
    device
        Device specification forwarded unchanged to ``torch.arange``.

    Returns
    -------
    A 1-D ``torch.long`` tensor of length ``size`` placed on ``device``.
    """
    indices = torch.arange(size, dtype=torch.int64, device=device)
    return indices

7 

8 

def add_positional_features(
    tensor: torch.Tensor,
    min_timescale: float = 1.0,
    max_timescale: float = 1.0e4,
):
    """
    Implements the frequency-based positional encoding described
    in `Attention is all you Need`.

    Sine and cosine signals are interleaved along the last dimension
    (sine on even indices, cosine on odd indices) and added to the input.

    Parameters
    ----------
    tensor : ``torch.Tensor``
        a Tensor with shape (batch_size, timesteps, hidden_dim).
    min_timescale : ``float``, optional (default = 1.0)
        The smallest timescale to use.
    max_timescale : ``float``, optional (default = 1.0e4)
        The largest timescale to use.

    Returns
    -------
    The input tensor augmented with the sinusoidal frequencies.
    """
    _, timesteps, hidden_dim = tensor.size()
    device = tensor.device

    timestep_range = torch.arange(timesteps, dtype=torch.float, device=device)
    # We're generating both cos and sin frequencies,
    # so half for each.
    num_timescales = hidden_dim // 2
    timescale_range = torch.arange(
        num_timescales, dtype=torch.float, device=device
    )

    # With a single timescale (hidden_dim of 2 or 3) the divisor
    # ``num_timescales - 1`` would be zero; clamp it to 1 so the lone
    # frequency is simply ``min_timescale`` instead of raising.
    log_timescale_increments = math.log(
        float(max_timescale) / float(min_timescale)
    ) / float(max(num_timescales - 1, 1))
    inverse_timescales = min_timescale * torch.exp(
        timescale_range * -log_timescale_increments
    )

    # Broadcasted multiplication - shape (timesteps, num_timescales)
    scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
    # shape (timesteps, 2 * num_timescales)
    # Every column is overwritten below, so allocate without drawing
    # from (and advancing) the global random number generator.
    sinusoids = torch.empty(timesteps, 2 * num_timescales, device=device)
    sinusoids[:, ::2] = torch.sin(scaled_time)
    sinusoids[:, 1::2] = torch.cos(scaled_time)
    if hidden_dim % 2 != 0:
        # if the number of dimensions is odd, the cos and sin
        # timescales had size (hidden_dim - 1) / 2, so we need
        # to add a column of zeros to make up the difference.
        sinusoids = torch.cat(
            [sinusoids, sinusoids.new_zeros(timesteps, 1)], 1
        )
    return tensor + sinusoids.unsqueeze(0)