Coverage for skema/img2mml/models/encoding/positional_features_for_cnn_encoder.py: 22%
18 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import math
2import torch
def get_range_vector(size: int, device) -> torch.Tensor:
    """
    Build the index vector ``[0, 1, ..., size - 1]``.

    Parameters
    ----------
    size : ``int``
        Number of entries in the returned vector.
    device
        Forwarded unchanged as the ``device`` argument of ``torch.arange``.

    Returns
    -------
    A 1-D ``torch.long`` tensor counting up from 0 to ``size - 1``.
    """
    indices = torch.arange(size, dtype=torch.int64, device=device)
    return indices
def add_positional_features(
    tensor: torch.Tensor,
    min_timescale: float = 1.0,
    max_timescale: float = 1.0e4,
) -> torch.Tensor:
    """
    Add sinusoidal positional encodings to ``tensor``.

    Implements the frequency-based positional encoding described
    in `Attention is all you Need`. For every timestep, sine and cosine
    values are computed over a geometric progression of timescales and
    interleaved along the hidden dimension (even columns hold the sines,
    odd columns the cosines).

    Parameters
    ----------
    tensor : ``torch.Tensor``
        a Tensor with shape (batch_size, timesteps, hidden_dim).
    min_timescale : ``float``, optional (default = 1.0)
        The smallest timescale to use.
    max_timescale : ``float``, optional (default = 1.0e4)
        The largest timescale to use.

    Returns
    -------
    The input tensor augmented with the sinusoidal frequencies.

    Raises
    ------
    ValueError
        If ``tensor`` is not 3-dimensional (the shape unpacking fails).
    """
    _, timesteps, hidden_dim = tensor.size()
    device = tensor.device

    timestep_range = torch.arange(
        timesteps, dtype=torch.long, device=device
    ).float()
    # We're generating both cos and sin frequencies,
    # so half for each.
    num_timescales = hidden_dim // 2
    timescale_range = torch.arange(
        num_timescales, dtype=torch.long, device=device
    ).float()

    # With a single timescale (hidden_dim of 2 or 3) there is no
    # progression to spread out, and ``num_timescales - 1`` would be zero;
    # clamp the divisor so that case no longer raises ZeroDivisionError.
    # The increment is multiplied by index 0 only, so its value is moot.
    log_timescale_increments = math.log(
        float(max_timescale) / float(min_timescale)
    ) / float(max(num_timescales - 1, 1))
    inverse_timescales = min_timescale * torch.exp(
        timescale_range * -log_timescale_increments
    )

    # Broadcasted multiplication - shape (timesteps, num_timescales)
    scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
    # Stack sin/cos as a trailing pair and flatten it so the columns come
    # out interleaved: sin_0, cos_0, sin_1, cos_1, ...
    # Built directly from the computed values (no random scratch buffer),
    # so the global RNG state is left untouched.
    # shape (timesteps, 2 * num_timescales)
    sinusoids = torch.stack(
        [torch.sin(scaled_time), torch.cos(scaled_time)], dim=2
    ).reshape(timesteps, 2 * num_timescales)
    if hidden_dim % 2 != 0:
        # if the number of dimensions is odd, the cos and sin
        # timescales had size (hidden_dim - 1) / 2, so we need
        # to add a column of zeros to make up the difference.
        sinusoids = torch.cat(
            [sinusoids, sinusoids.new_zeros(timesteps, 1)], 1
        )
    return tensor + sinusoids.unsqueeze(0)