Coverage for skema/img2mml/models/encoders/cnn_encoder.py: 71%
38 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import torch
2import torch.nn as nn
3from skema.img2mml.models.encoding.positional_features_for_cnn_encoder import (
4 add_positional_features,
5)
6from skema.img2mml.models.encoding.row_encoding import RowEncoding
class CNN_Encoder(nn.Module):
    """
    CNN feature extractor for the image-to-MathML encoder.

    Six convolutional layers (with interleaved BatchNorm / MaxPool)
    map a source image to a (B, 512, H', W') feature map, which is
    then projected either through a row encoder (LSTM over rows) or a
    flat linear projection to ``dec_hid_dim``.
    """

    def __init__(self, input_channels, dec_hid_dim, dropout, device):
        """
        :param input_channels: number of channels of the source image
        :param dec_hid_dim: hidden size of the decoder (output feature size)
        :param dropout: dropout probability, forwarded to the row encoder
        :param device: torch device the module's buffers are placed on
        """
        super(CNN_Encoder, self).__init__()

        self.device = device
        # NOTE(review): `scale` is created but never used in this class —
        # possibly consumed by a subclass or left over from a refactor.
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(self.device)
        self.kernel = (3, 3)
        self.padding = (1, 1)
        self.stride = (1, 1)
        self.re = RowEncoding(device, dec_hid_dim, dropout)
        # Projects the 512-channel CNN features to the decoder hidden size.
        self.linear = nn.Linear(512, dec_hid_dim)

        self.cnn_encoder = nn.Sequential(
            # layer 1: [batch, Cin, w, h] -> 64 channels, spatial /2
            nn.Conv2d(
                input_channels,
                64,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # layer 2: 64 -> 128 channels, spatial /2
            nn.Conv2d(
                64,
                128,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # layer 3: 128 -> 256 channels
            nn.Conv2d(
                128,
                256,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # layer 4: [B, 256, w, h], width halved only
            nn.Conv2d(
                256,
                256,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)),
            # layer 5: 256 -> 512 channels
            nn.Conv2d(
                256,
                512,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # layer 6: 512 -> 512 channels
            nn.Conv2d(
                512,
                512,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
        )
        self.init_weights()

    def init_weights(self):
        """
        Initialize Conv/Linear weights from N(0, 0.1) and their biases
        to 0; BatchNorm weights to 1 and biases to 0.

        BUGFIX: the previous implementation matched the substrings
        "nn.Conv2d" / "nn.BatchNorm2d" against ``named_parameters()``
        names, but those names look like "0.weight" / "4.bias" and never
        contain a class name — so no parameter was ever initialized.
        Dispatching on the module type fixes this.
        """
        for module in self.cnn_encoder.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight.data, mean=0, std=0.1)
                if module.bias is not None:
                    nn.init.constant_(module.bias.data, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight.data, 1)
                nn.init.constant_(module.bias.data, 0)

    def forward(self, src, encoding_type=None):
        """
        Encode a batch of source images.

        :param src: image tensor of shape (B, Cin, W, H)
        :param encoding_type: "row_encoding", "positional_features", or
            None for the plain linear projection
        :return: for "row_encoding" a tuple (output, hidden, cell);
            otherwise a tensor of shape (B, L, dec_hid_dim) with L = H*W
        """
        output = self.cnn_encoder(src)  # (B, 512, W, H)

        if encoding_type == "row_encoding":
            # output: (B, H*W, dec_hid_dim); hidden, cell: [1, B, dec_hid_dim]
            output, hidden, cell = self.re(output)
            return output, hidden, cell

        else:
            output = torch.flatten(output, 2, -1)  # (B, 512, L=H*W)
            output = output.permute(0, 2, 1)  # (B, L, 512)
            if encoding_type == "positional_features":
                # NOTE(review): if add_positional_features already returns
                # input + encodings (AllenNLP convention), this `+=`
                # double-adds the input features — verify against the helper.
                output += add_positional_features(output)  # (B, L, 512)
            return self.linear(output)  # (B, L, dec_hid_dim)