首页 Paddle框架 帖子详情
einsum-计算不同维度的矩阵乘法怎么实现
收藏
快速回复
Paddle框架 问答深度学习 1428 3
einsum-计算不同维度的矩阵乘法怎么实现
收藏
快速回复
Paddle框架 问答深度学习 1428 3

我需要实现PyTorch中einsum函数的一个功能,pytorch代码如下:

torch.einsum('bkhw,bckhw->bchw', [W, X])

四维矩阵与五维矩阵相乘,两个矩阵第0维和后两个维度相同,如代码所示,其余维度k维度中元素对应相乘后相加,将此标量放在c维度中,最后输出维度为bchw。

请问用paddle如何实现,求各路大神帮忙想想办法?

 

0
收藏
回复
全部评论(3)
时间顺序
l
love梦的畅想
#2 回复于2021-10
def einsum( subscripts, *matrices ):

    def combo( n ):
        combination = [ 0 ] * len( n )
        while ( True ):
            yield combination
            p = len( n ) - 1;        
            combination[ p ] += 1
            while ( combination[ p ] == n[ p ] ):
                combination[ p ] = 0
                p -= 1
                if ( p < 0 ):
                    return
                combination[ p ] += 1

    subscripts = subscripts.split( "->" )
    opt_idx_str = subscripts[ 0 ]
    res_idx_str = subscripts[ 1 ]
    opt_idx_str = opt_idx_str.split( "," )
    
    all_idx = [ ]
    idx2dim = { }
    matrix2idx = { }
    for i, idx_str in enumerate( opt_idx_str ):
        matrix = matrices[ i ]
        idx_list = list( idx_str.strip( ) )
        matrix2idx[ id( matrix ) ] = [ ]
        
        for j, each_idx in enumerate( idx_list ):
            if each_idx not in all_idx:
                all_idx.append( each_idx )
            if each_idx not in idx2dim.keys():
                idx2dim[ each_idx ] = matrix.shape[ j ]
            elif idx2dim[ each_idx ] != matrix.shape[ j ]:
                # print(idx2dim, each_idx, matrix.shape[ j ])
                raise('idx2dim error!')
            matrix2idx[ id( matrix ) ].append( each_idx )
    res_idx = list(res_idx_str)
    res_dim = [idx2dim[idx] for idx in res_idx]
    result = paddle.zeros( [ idx2dim[ idx ] for idx in list( res_idx ) ] )
    matrix2idx[ id(result) ] = res_idx
    non_idx = []
    for idx in all_idx:
        if idx not in res_idx:
            non_idx.append(idx)

    idxTracker = {}
    for idx in all_idx:
        idxTracker[idx] = slice(0,idx2dim[idx]) if idx in non_idx else 0
    
    for idxTuple in combo( res_dim ):
        temp = paddle.ones([ idx2dim[ idx ] for idx in non_idx ])
        for i, idx in enumerate(res_idx):
            idxTracker[ idx ] = idxTuple[ i ]
        for matrix in matrices:
            # print(idxTracker)
            matIndices = tuple( [ idxTracker[ idx ] for idx in matrix2idx[ id( matrix ) ] ] )
            matrix = matrix[ matIndices ]
            # print(matrix.shape, matIndices)
            if matIndices != [ idx2dim[ idx ] for idx in non_idx ]:
                # print(matrix.shape, [ idx2dim[ idx ] for idx in non_idx ])
                matrix = paddle.reshape(matrix, [ idx2dim[ idx ] for idx in non_idx ])
            temp *= matrix
        matIndices = tuple( [ idxTracker[ idx ] for idx in matrix2idx[ id( result ) ] ] )
        result[ matIndices ] += temp.sum()
            
    return result
0
回复
l
love梦的畅想
#3 回复于2021-10

python版,没有并行优化,后续会改进成算子的

0
回复
l
love梦的畅想
#4 回复于2021-10

啊我错了,还不能反向传播

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户