59 lines
No EOL
2.1 KiB
Python
59 lines
No EOL
2.1 KiB
Python
# Stolen from https://github.com/leeyeehoo/CSRNet-pytorch.git
|
|
|
|
import torch.nn as nn
|
|
import torch
|
|
from torchvision import models
|
|
from utils import save_net,load_net
|
|
|
|
class CSRNet(nn.Module):
|
|
def __init__(self, load_weights=False):
|
|
super(CSRNet, self).__init__()
|
|
|
|
# Ref. 2018 paper
|
|
self.seen = 0
|
|
self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
|
|
self.backend_feat = [512, 512, 512, 256, 128, 64] # 4-parallel, 1, 2, 2-then-4, 4 dilation rates
|
|
|
|
self.frontend = make_layers(self.frontend_feat)
|
|
self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
|
|
self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
|
|
if not load_weights:
|
|
mod = models.vgg16(pretrained = True)
|
|
self._initialize_weights()
|
|
for i in range(len(self.frontend.state_dict().items())):
|
|
self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:]
|
|
|
|
def forward(self,x):
|
|
x = self.frontend(x)
|
|
x = self.backend(x)
|
|
x = self.output_layer(x)
|
|
return x
|
|
|
|
def _initialize_weights(self):
|
|
for m in self.modules():
|
|
if isinstance(m, nn.Conv2d):
|
|
nn.init.normal_(m.weight, std=0.01)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
nn.init.constant_(m.weight, 1)
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
def make_layers(cfg, in_channels = 3, batch_norm=False, dilation=False):
|
|
if dilation:
|
|
d_rate = 2
|
|
else:
|
|
d_rate = 1
|
|
layers = []
|
|
for v in cfg:
|
|
if v == 'M':
|
|
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
|
else:
|
|
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
|
|
if batch_norm:
|
|
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
|
|
else:
|
|
layers += [conv2d, nn.ReLU(inplace=True)]
|
|
in_channels = v
|
|
return nn.Sequential(*layers) |