Loss revamp & Renamed model to network
parent 0d35d607fe · commit 9d2a30a226
7 changed files with 44 additions and 32 deletions
59
network/csrnet.py
Normal file
@@ -0,0 +1,59 @@
# Stolen from https://github.com/leeyeehoo/CSRNet-pytorch.git
import torch
import torch.nn as nn
from torchvision import models

from utils import save_net, load_net


class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()

        # Ref. 2018 paper
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]  # 4-parallel, 1, 2, 2-then-4, 4 dilation rates

        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # Copy the pretrained VGG-16 weights into the front end.
            # state_dict().items() is wrapped in list() so it can be indexed on Python 3,
            # where dict views are not subscriptable.
            frontend_params = list(self.frontend.state_dict().items())
            vgg_params = list(mod.state_dict().items())
            for i in range(len(frontend_params)):
                frontend_params[i][1].data[:] = vgg_params[i][1].data[:]

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
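For context (not part of the commit), a minimal usage sketch under stated assumptions: the file above is importable as network.csrnet per this commit's layout, and the input is an RGB tensor; the batch size and image dimensions below are arbitrary placeholders.

    import torch
    from network.csrnet import CSRNet  # module path assumed from this commit

    # Build the network; pretrained VGG-16 weights are copied into the front end.
    net = CSRNet(load_weights=False).eval()

    # Dummy RGB batch: 1 image, 3 channels, 384x512 pixels (sizes are illustrative).
    x = torch.randn(1, 3, 384, 512)

    with torch.no_grad():
        density = net(x)  # shape (1, 1, 48, 64): three 2x max-pools shrink H and W by 8

    # The predicted crowd count is the sum of the density map.
    count = density.sum().item()
    print(density.shape, count)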