Coverage for skema/img2mml/models/encoders/cnn_encoder.py: 71%

38 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import torch 

2import torch.nn as nn 

3from skema.img2mml.models.encoding.positional_features_for_cnn_encoder import ( 

4 add_positional_features, 

5) 

6from skema.img2mml.models.encoding.row_encoding import RowEncoding 

7 

8 

class CNN_Encoder(nn.Module):
    """
    Convolutional feature extractor for images of mathematical expressions.

    Depending on ``encoding_type`` passed to :meth:`forward`, the encoder
    returns either row-encoded features (an RNN run over feature-map rows)
    or flattened features, optionally augmented with positional features,
    projected to the decoder's hidden dimension.
    """

    def __init__(self, input_channels, dec_hid_dim, dropout, device):
        """
        :param input_channels: number of channels in the source image
        :param dec_hid_dim: hidden size of the decoder's RNN; also the
            output feature dimension of this encoder
        :param dropout: dropout probability, forwarded to RowEncoding
        :param device: device the encoder's tensors are placed on
        """
        super(CNN_Encoder, self).__init__()

        self.device = device
        # NOTE(review): self.scale is never used anywhere in this class —
        # candidate for removal; kept in case external code reads it.
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(self.device)
        self.kernel = (3, 3)
        self.padding = (1, 1)
        self.stride = (1, 1)
        self.re = RowEncoding(device, dec_hid_dim, dropout)
        # Projects the 512-channel CNN output to the decoder hidden size.
        self.linear = nn.Linear(512, dec_hid_dim)

        self.cnn_encoder = nn.Sequential(
            # layer 1: [batch, Cin, w, h]
            nn.Conv2d(
                input_channels,
                64,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # layer 2
            nn.Conv2d(
                64,
                128,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            # layer 3
            nn.Conv2d(
                128,
                256,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # layer 4: [B, 256, w, h]
            nn.Conv2d(
                256,
                256,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)),
            # layer 5
            nn.Conv2d(
                256,
                512,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # layer 6: [B, 512, 10, 33]
            nn.Conv2d(
                512,
                512,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            ),
            nn.ReLU(),
        )
        self.init_weights()

    def init_weights(self):
        """
        Initialize conv/linear weights from N(0, 0.1) and their biases to 0;
        BatchNorm weights to 1 and biases to 0.

        Bug fix: the previous implementation matched the substrings
        "nn.Conv2d" / "nn.BatchNorm2d" against parameter names, but
        ``named_parameters()`` on a Sequential yields names like
        "0.weight" or "4.bias", so no branch ever matched and the custom
        initialization was a silent no-op (PyTorch defaults applied
        instead). We now dispatch on module type via isinstance.
        """
        for module in self.cnn_encoder.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight.data, mean=0, std=0.1)
                # Conv2d bias can be disabled (bias=False); guard for safety.
                if module.bias is not None:
                    nn.init.constant_(module.bias.data, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight.data, 1)
                nn.init.constant_(module.bias.data, 0)

    def forward(self, src, encoding_type=None):
        """
        :param src: source image batch, (B, input_channels, W, H)
        :param encoding_type: "row_encoding", "positional_features", or None
        :return: for "row_encoding", a tuple (output, hidden, cell);
            otherwise a tensor (B, L, dec_hid_dim) with L = H*W
        """
        output = self.cnn_encoder(src)  # (B, 512, W, H)

        if encoding_type == "row_encoding":
            # output: (B, H*W, dec_hid_dim)
            # hidden, cell: [1, B, dec_hid_dim]
            output, hidden, cell = self.re(output)
            return output, hidden, cell

        else:
            output = torch.flatten(output, 2, -1)  # (B, 512, L=H*W)
            output = output.permute(0, 2, 1)  # (B, L, 512)
            if encoding_type == "positional_features":
                # NOTE(review): if add_positional_features already returns
                # input + encodings (as its name suggests), "+=" here yields
                # 2*output + encodings — confirm against that helper's
                # implementation before changing.
                output += add_positional_features(output)  # (B, L, 512)
            return self.linear(output)  # (B, L, dec_hid_dim)