覚え書きブログ

ブロック単位のreshape その2

hirotaka-hachiya.hatenablog.com
前回のブロック単位でのreshapeはある特定の行列の形にしか対応できなかったが、以下のサイトによるとnumpy.hsplitとnumpy.vstackを組み合わせるとどんな行列にも対応できそう。
stackoverflow.com

以下のように、4行6列の行列をnumpy.hsplitで行方向に3つに分割し、そのあとnumpy.vstackを用いて縦方向に繋げる。

> z=np.arange(24).reshape(4,6)
> z
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])
> np.hsplit(z,3)
[array([[ 0,  1],
       [ 6,  7],
       [12, 13],
       [18, 19]]), array([[ 2,  3],
       [ 8,  9],
       [14, 15],
       [20, 21]]), array([[ 4,  5],
       [10, 11],
       [16, 17],
       [22, 23]])]
> np.vstack(np.hsplit(z,3))
array([[ 0,  1],
       [ 6,  7],
       [12, 13],
       [18, 19],
       [ 2,  3],
       [ 8,  9],
       [14, 15],
       [20, 21],
       [ 4,  5],
       [10, 11],
       [16, 17],
       [22, 23]])

そうすると、2行2列のブロック単位で縦に並べることができる。例えば、[[0,1],[6,7]]や[[2,3],[8,9]]がブロックである。
しかしながら、[[0,1],[6,7]]の後に本来は[[2,3],[8,9]]が来てほしいところ、元々下にある[[12,13],[18,19]]が来てしまっていて、ブロックの順番が期待通りにはなっていなかった。

そこで、今度は、numpy.vsplitを用いて縦方向に2つに分割してから、各固まりをnumpy.hsplitを用いてそれぞれ横方向に3つに分割してから、numpy.vstackで縦方向に並べてみた。
リスト内包表記が入り、少し複雑にはなるが以下のようになる。

>np.vstack([np.hsplit(np.vsplit(z,2)[i],3) for i in range(2)]).reshape(12,2)
array([[ 0,  1],
       [ 6,  7],
       [ 2,  3],
       [ 8,  9],
       [ 4,  5],
       [10, 11],
       [12, 13],
       [18, 19],
       [14, 15],
       [20, 21],
       [16, 17],
       [22, 23]])

リスト内包表記なしでもnumpy.hstackを使えばおなじことができた。

>np.vstack(np.hsplit(np.hstack(np.vsplit(z,2)),6))
array([[ 0,  1],
       [ 6,  7],
       [ 2,  3],
       [ 8,  9],
       [ 4,  5],
       [10, 11],
       [12, 13],
       [18, 19],
       [14, 15],
       [20, 21],
       [16, 17],
       [22, 23]])

同じ原理で、tensorflowでもsplitとconcatを用いてブロック単位でのreshapeができる。

> z=tf.constant(z)
> z
<tf.Tensor: shape=(4, 6), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])>
> tf.concat(tf.split(tf.concat(tf.split(z,2,axis=0),axis=1),6,axis=1),axis=0)
<tf.Tensor: shape=(12, 2), dtype=int32, numpy=
array([[ 0,  1],
       [ 6,  7],
       [ 2,  3],
       [ 8,  9],
       [ 4,  5],
       [10, 11],
       [12, 13],
       [18, 19],
       [14, 15],
       [20, 21],
       [16, 17],
       [22, 23]])>