프로그래밍/python

tf.transpose 함수 사용하기

ksyoon 2021. 1. 11. 22:51

텐서플로의 transpose 함수는 행렬 연산의 transpose 를 구현하는 함수이다.

함수의 사용법은 아래와 같고,

tf.transpose(
    a
, perm=None, conjugate=False, name='transpose'
)

a A Tensor.
perm A permutation of the dimensions of a. This should be a vector.
conjugate Optional bool. Setting it to True is mathematically equivalent to tf.math.conj(tf.transpose(input)).
name A name for the operation (optional).

Args

A transposed Tensor.

Returns

Numpy Compatibility

In numpy transposes are memory-efficient constant time operations as they simply return a new view of the same data with adjusted strides.

TensorFlow does not support strides, so transpose returns a new tensor with the items permuted.

 

python에서의 사용 예를 보면

  1. import tensorflow as tf

  2. x = tf.constant([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])

  3. print("x shape is ",tf.shape(x))

  4.  

  5. print("Default transpose")

  6. print(tf.transpose(x))

  7.  

  8. print("Costom transpose")

  9. print(tf.transpose(x,perm=[0,2,1]))

와 같고 이를 실행하면 다음과 같은 결과가 나온다.

  1. x shape is  tf.Tensor([2 2 3], shape=(3,), dtype=int32)

  2. Default transpose

  3. tf.Tensor(

  4. [[[ 1  7]

  5.   [ 4 10]]

  6.  

  7.  [[ 2  8]

  8.   [ 5 11]]

  9.  

  10.  [[ 3  9]

  11.   [ 6 12]]], shape=(3, 2, 2), dtype=int32)

  12. Costom transpose

  13. tf.Tensor(

  14. [[[ 1  4]

  15.   [ 2  5]

  16.   [ 3  6]]

  17.  

  18.  [[ 7 10]

  19.   [ 8 11]

  20.   [ 9 12]]], shape=(2, 3, 2), dtype=int32)

 

위 예에서 텐서의 사이즈를 보면 2x2x3 인데, 첫번째 2가 0축, 두번째 2가 1축 세번째 3이 2축이 각각 된다. perm 파라미터를 생략하면 perm = [2,1,0] 가 기본 값으로 들어가고, 이는 0번째 축과  2번째 축을 바꾸는 오퍼레이션을 하겠다는 의미기다.

 

이해가 잘 안가면 텐서의 각 값에 인덱스를 붙여 보면 이해가 쉽다.

 

$$ \begin{split} X &=\begin{bmatrix} 1_{111} & 2_{112}  & 3_{113} \\4_{121} & 5_{122}  & 6_{123} \\ 7_{211} & 8_{212}  & 9_{213}\\ 10_{221} & 11_{222}  & 12_{223} \end{bmatrix} \\ \end{split}$$

 

여기서 \(X\)의  0번축과 2번축을 바꾸는 transpose를 구해보면

$$ \begin{split} X^t &=\begin{bmatrix} 1_{111} & 2_{112}  & 3_{113} \\4_{121} & 5_{122}  & 6_{123} \\ 7_{211} & 8_{212}  & 9_{213}\\ 10_{221} & 11_{222}  & 12_{223} \end{bmatrix}^t \\ \end{split}$$

이것은 각 원소들의 \(a^t_{ijk} = a_{kji} \)로 바꾸는 연산을 수행하는 것이다.