Abstract

视频理解领域,高准确率和低计算成本的权衡一直是一个十分重要的问题。传统的2D卷积计算量很小,但是只能捕捉图像中的空间信息,无法捕捉帧与帧之间的时序信息。基于3D卷积的方法虽然效果很好,但是计算成本太高,实际部署困难。这篇文章中,我们提出了一种泛化性强、效率高的模块(Temporal Shift Module),能够同时实现高效和高精度,能够实现3D卷积精度的基础上,仍然保持2D卷积的计算成本。TSM在时间维度上面移动了部分通道,便于在相邻帧图像之间交换信息。TSM模块可以插入2D卷积模型中,在不增加任何计算成本的条件下实现时间建模。我们还将TSM应用在在线任务中,实现了实时低延迟的在线视频识别和视频对象检测。TSM准确而有效,在Something-Something数据集榜单上面排行第一,在Galaxy Note8手机上面,实现了35ms低延迟的在线视频识别。代码可以在这里找到

TSM模块

TSM通过沿着时间维度移动特征图通道来执行有效的时间建模。相较于2D卷积,并没有增加计算量,但是TSM结构具有强大的时间建模能力。TSM能够有效支持在线和离线视频分类任务,双向TSM将过去的帧和未来的帧以及当前帧进行混合,适应于高吞吐量的离线视频识别。单向TSM将过去的帧与当前帧进行混合,适应于低延迟的在线视频识别

具体来讲,视频任务的每一层feature map我们可以将其视为是一个$R^{NxCxTxHxW}$矩阵,其中N代表batch size,C代表的是feature的通道数,T代表的是时间维度,H和W代表的是空间维度。传统的2D卷积沿着维度T独立应用在每一帧图像上,因此没有利用时间信息(图a所示)。相比下,TSM沿着时间维度同时向前向后移动channel(图b所示),通过移动channel,相邻帧图像的信息与当前帧进行了混合。这种思想的出发点是:卷积运算包括位移和乘法累加,我们将时间维度偏移+-1,并将乘积从时间维度折叠到通道维度。为了实现在线视频理解,将来的帧不能与当前帧进行融合,因此提出了单向TSM(图c所示)

考虑一个简单的卷积操作,这里为了简化,采用1维卷积,kernel size设置为3。假设卷积的权重是$W=(w_1, w_2, w_3)$,输入是向量$X$,卷积操作可以表示为$Y=Conv(W, X)$,分解$Y_i = w_1X_{i-1}+w_2X_i+w_3X_{i+1}$,我们可以将卷积操作分解为两步:位移(shift)以及乘法累加(multiply-accumulate),我们利用-1, 0, +1位移$X$,然后利用$w_1, w_2, w_3$做乘积,累加之后得到对应的$Y$。shift操作表示为:

$$X_{i}^{-1}=X_{i-1}, X_i^0=X_i, X_i^{+1}=X_{i+1}$$

乘法累加操作表示为

$$Y=w_1X^{-1}+w_2X^0+w_3X^{+1}$$

第一步shift操作可以在没有任何乘法的条件下实现,第二步的乘法累加计算量更大,TSM模块将乘法累加操作融合到了接下来的2D卷积中去,因此相较于原来的2D卷积操作,并没有造成任何额外的计算损失

代码实现

class TemporalShift(nn.Module):
    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super(TemporalShift, self).__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # Due to some out of order error when performing parallel computing. 
            # May need to write a CUDA kernel.
            raise NotImplementedError  
            # out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift

        return out.view(nt, c, h, w)


论文

本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!