看不懂我吃屎系列——CTPN

Posted by 高庆东 on May 30, 2019

CTPN

主要是作用是来检测图片上的文字对象,并画一个框出来,目的是画出可能是文字的框,下一步的输入可以是ocr识别,
先说一下他的网络结果, 其很就是在FasterRCNN的基础上增加一个RNN的结构,看了一下之前的FasterRCNN好像
有些没说明白,这次全部写了。从结构到代码。结合代码说原理。它的大结果快主要有VGG,RNN,RPN 分类回归,非 极大值抑制用来处理最终预测结果

先介绍一个我们需要的训练数据和处理结果

输出当然是一张图片和图片的标签,也就是图像的框坐标数据,在CTPN中没有类别数据因为判断是不是文字直接就是
二分类问题。先说图片,首先读取一张图片然后读标签坐标然后生成锚点框坐标,为啥进行这个操作呢,因为我们经 过VGG后生成的特征图与原始图像有一个比例关系,通过这个比例来计算锚点个数和框坐标。具体操作看代码吧

#使用torch写代码所以你懂的,写一个处理数据的子类,该类可以生成 *一个* 样本的训练数据
class VOCDataset(Dataset):
    def __init__(self,
                 datadir,
                 labelsdir):
        '''
        :param txtfile: image name list text file
        :param datadir: image's directory
        :param labelsdir: annotations' directory
        '''
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))

        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        print(img_path)
        xml_path = os.path.join(self.labelsdir, img_name.replace('.jpg', '.xml'))
        gtbox, _ = readxml(xml_path)
        img = cv2.imread(img_path)
        h, w, c = img.shape
        # clip image
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr], _ = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)  #[labels, bbox_targets], base_anchor
        #cls表示经过iou和其他方式筛选的类别里面主要是-1,0,1  #regr表示所有锚点框回归的关系数
        m_img = img - IMAGE_MEAN

        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])

        cls = np.expand_dims(cls, axis=0)  #扩展维度

        # transform to torch tensor
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()  
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()

        return m_img, cls, regr   #data中的coll_fn哪个函数默认的时候会自动再外面加一维,为batchsize维

看看计算锚点的代码,这个代码的逻辑是,首先定义一个锚点对应10个不同比例的框,然后计算一共有多少锚点
然后计算每个锚点框的在原图的坐标再原图上标出来

def gen_anchor(featuresize, scale):
    """
        gen base anchor from feature map [HXW][9][4]
        reshape  [HXW][9][4] to [HXWX9][4]
    """
    #设定不同尺寸的锚点框每个锚点对应10个框
    heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
    widths = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]

    # gen k=9 anchor size (h,w)
    heights = np.array(heights).reshape(len(heights), 1)
    widths = np.array(widths).reshape(len(widths), 1)
    # 计算一个锚点框的中心坐标和大小,
    base_anchor = np.array([0, 0, 15, 15])
    # center x,y
    xt = (base_anchor[0] + base_anchor[2]) * 0.5
    yt = (base_anchor[1] + base_anchor[3]) * 0.5
    #计算一个锚点对应的所有框的的坐标,
    # x1 y1 x2 y2
    x1 = xt - widths * 0.5
    y1 = yt - heights * 0.5
    x2 = xt + widths * 0.5
    y2 = yt + heights * 0.5
    base_anchor = np.hstack((x1, y1, x2, y2))  #水平方向堆叠
    #将锚点分布到原图上时锚点的坐标
    h, w = featuresize
    shift_x = np.arange(0, w) * scale
    shift_y = np.arange(0, h) * scale
    # apply shift
    anchor = []
    for i in shift_y:
        for j in shift_x:
            # 画出所有在原图上的锚点框的坐标
            #这里区分一下,锚点时一个缩放后中心点,这个中心点对应10个框就是锚点框。
            anchor.append(base_anchor + [j, i, j, i])
    return np.array(anchor).reshape((-1, 4))

计算每个锚点框针对实际边框的IOU值,加入一张图片上有3个框,那就拿出一个锚点框,分别计算每个的IOU,
结果存储,遍历所有的锚点框然后会生成一个矩阵矩阵元素表示第几个锚点框与第几个实际框的IOU的值。

def cal_overlaps(boxes1, boxes2):
    """
    boxes1 [Nsample,x1,y1,x2,y2]  anchor
    boxes2 [Msample,x1,y1,x2,y2]  grouth-box

    """
    area1 = (boxes1[:, 0] - boxes1[:, 2]) * (boxes1[:, 1] - boxes1[:, 3])
    area2 = (boxes2[:, 0] - boxes2[:, 2]) * (boxes2[:, 1] - boxes2[:, 3])

    overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))

    # calculate the intersection of  boxes1(anchor) and boxes2(GT box)
    for i in range(boxes1.shape[0]):   #遍历每个锚点框
        overlaps[i][:] = cal_iou(boxes1[i], area1[i], boxes2, area2)  #依次比较每个IOU
    return overlaps

通过某些过滤条件过滤掉一些框,过滤条件有IOU阈值等等,顺便为一些框打标签这个标签主要依据就是IOU阈值,大于某个阈值认为是正标签定义为1
小于某个阈值定义为0 其他的阈值定义为-1。

    anchor_argmax_overlaps = overlaps.argmax(axis=1)  # 每个锚点对应的GT的最大iou的位置
    anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps]  #每个锚点对应所有的GT的最大iou的位置

    # IOU > IOU_POSITIVE
    labels[anchor_max_overlaps > IOU_POSITIVE] = 1
    # IOU <IOU_NEGATIVE
    labels[anchor_max_overlaps < IOU_NEGATIVE] = 0
    # ensure that every GT box has at least one positive RPN region
    labels[gt_argmax_overlaps] = 1

找多所有框对框打一个标签后我们需要再找,每个锚点框对应的实际框,因为锚点框很多每个锚点框会
和每个实际框计算IOU最大IOU就是这个锚点框对应的实际框。计算锚点框变化为实际框时的变化值,
这个值就是回归的目标值,你品是不是哪里不太对。宏观上看我们在一张图上画出很多框,这些框
会变化成实际框,变化量是回归的值的标签,AW=G A表示所有锚点框,W表示变化系数,G表示实际框。 首先需要计算这个W,通过A和G来计算

def bbox_transfrom(anchors, gtboxes): #输入是所有锚点框和每个框对应的GT框

    regr = np.zeros((anchors.shape[0], 2))   #只对y做回归
    Cy = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5   #当[:,1]会降维 变成1行
    Cya = (anchors[:, 1] + anchors[:, 3]) * 0.5
    h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0
    ha = anchors[:, 3] - anchors[:, 1] + 1.0

    Vc = (Cy - Cya) / ha
    Vh = np.log(h / ha)

    return np.vstack((Vc, Vh)).transpose() #按行堆叠数组 #转置   #返回GT与anchors的变化关系

最终VOCDataset()返回值有三个第一个是图片数据。第二个是label,是每个锚点框的标签,
有0,1,-1,第三个返回值是回归的值与label的合并,回归的值是每个锚点框变化成对应G的 对应W,好了数据部分处理完了,该模型了,模型简单

模型结构

首先是VGG做特征提取这部分不多说了,现在看下一部分 ,输出特征图后会经过两条路劲一个是
做分类一个是做回归

class BasicConv(nn.Module):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 relu=True,
                 bn=True,
                 bias=True):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CTPN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.vgg16(pretrained=False)
        layers = list(base_model.features)[:-1]
        self.base_layers = nn.Sequential(*layers)  # block5_conv3 output
        self.rpn = BasicConv(512, 512, 3,1,1,bn=False)
        self.brnn = nn.GRU(512,128, bidirectional=True, batch_first=True)  #这一步比较巧妙,其实就是对结果做形状变化。做文字方向的均分
        self.lstm_fc = BasicConv(256, 512,1,1,relu=True, bn=False)
        self.rpn_class = BasicConv(512, 10*2, 1, 1, relu=False,bn=False)
        self.rpn_regress = BasicConv(512, 10 * 2, 1, 1, relu=False, bn=False)

    def forward(self, x):
        x = self.base_layers(x)
        # rpn
        x = self.rpn(x)

        x1 = x.permute(0,2,3,1).contiguous()  # channels last
        b = x1.size()  # batch_size, h, w, c
        x1 = x1.view(b[0]*b[1], b[2], b[3])

        x2, _ = self.brnn(x1)  #经过GRU的结果

        xsz = x.size()  #batch_size, c, h, w
        x3 = x2.view(xsz[0], xsz[2], xsz[3], 256)  # torch.Size([4, 20, 20, 256])

        x3 = x3.permute(0,3,1,2).contiguous()  # channels first
        x3 = self.lstm_fc(x3)
        x = x3

        cls = self.rpn_class(x)
        regr = self.rpn_regress(x)

        cls = cls.permute(0,2,3,1).contiguous()
        regr = regr.permute(0,2,3,1).contiguous()

        cls = cls.view(cls.size(0), cls.size(1)*cls.size(2)*10, 2)
        regr = regr.view(regr.size(0), regr.size(1)*regr.size(2)*10, 2)  #

        return cls, regr #输出的是每个锚点框的类别和变化系数W,
     

一张图片经过特征网络后输出有锚点框的类别和变化系数,我们使用的就是变化系数做回归回归
的结果可以这样理解。 在训练数据上我们知道一张图片的所有锚点框A,同时知道实际框G,我们
可以计算出变化系数W,W会作为标签。然后实际的图片经过网络会得到每个锚点框的变化系数 W_
让这个变化系数无限趋近于W这就是网络的作用,具体的操作中会更具VOCDataset()给出的标签
数据对网络输出的框进行筛选,然后再回归计算更新网络,但是这个网络在预测时并没有筛选过程,
只通过类别概率进行了筛选。


    dataset = VOCDataset(args['image_dir'], args['labels_dir'])  #读数据
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args['num_workers'])
    model = CTPN_Model()  #导入模型
    model.to(device)
    
    if os.path.exists(checkpoints_weight):
        print('using pretrained weight: {}'.format(checkpoints_weight))
        cc = torch.load(checkpoints_weight, map_location=device)
        model.load_state_dict(cc['model_state_dict'])
        resume_epoch = cc['epoch']

    params_to_uodate = model.parameters()
    optimizer = optim.SGD(params_to_uodate, lr=lr, momentum=0.9)
    
    critetion_cls = RPN_CLS_Loss(device)  
    critetion_regr = RPN_REGR_Loss(device)  #计算回归损失
    
    best_loss_cls = 100
    best_loss_regr = 100
    best_loss = 100
    best_model = None
    epochs += resume_epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) #动态调整学习率
    
    for epoch in range(resume_epoch+1, epochs):
        print(f'Epoch {epoch}/{epochs}')
        print('#'*50)
        epoch_size = len(dataset) // 1
        model.train()
        epoch_loss_cls = 0
        epoch_loss_regr = 0
        epoch_loss = 0
        scheduler.step(epoch)
    
        for batch_i, (imgs, clss, regrs) in enumerate(dataloader):
            imgs = imgs.to(device)
            clss = clss.to(device)
            regrs = regrs.to(device)
    
            optimizer.zero_grad()
    
            out_cls, out_regr = model(imgs)
            loss_cls = critetion_cls(out_cls, clss)
            loss_regr = critetion_regr(out_regr, regrs)  #计算回归损失
    
            loss = loss_cls + loss_regr  # total loss
            loss.backward()
            optimizer.step()

回归损失是这样写的

class RPN_CLS_Loss(nn.Module):
    def __init__(self,device):
        super(RPN_CLS_Loss, self).__init__()
        self.device = device

    def forward(self, input, target):
        y_true = target[0][0]
        cls_keep = (y_true != -1).nonzero()[:, 0] #会根据标签筛掉一些框,再计算回归值,也就是说,我们得到的模型必须输入比较好的框才能得到好的结果,但是实际预测时,我们输入的是全部的框得到结果,所以这点不合理
        cls_true = y_true[cls_keep].long()   
        cls_pred = input[0][cls_keep]
        loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true)  # original is sparse_softmax_cross_entropy_with_logits
        # loss = nn.BCEWithLogitsLoss()(cls_pred[:,0], cls_true.float())  # 18-12-8
        loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0) #限制再0到10之间
        return loss.to(self.device)