Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기

프로그래밍/python

tf.transpose 함수 사용하기

텐서플로의 transpose 함수는 행렬 연산의 transpose 를 구현하는 함수이다.

함수의 사용법은 아래와 같고,

tf.transpose(
    a
, perm=None, conjugate=False, name='transpose'
)

a A Tensor.
perm A permutation of the dimensions of a. This should be a vector.
conjugate Optional bool. Setting it to True is mathematically equivalent to tf.math.conj(tf.transpose(input)).
name A name for the operation (optional).

Args

A transposed Tensor.

Returns

Numpy Compatibility

In numpy transposes are memory-efficient constant time operations as they simply return a new view of the same data with adjusted strides.

TensorFlow does not support strides, so transpose returns a new tensor with the items permuted.

 

python에서의 사용 예를 보면

  1. import tensorflow as tf

  2. x = tf.constant([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])

  3. print("x shape is ",tf.shape(x))

  4.  

  5. print("Default transpose")

  6. print(tf.transpose(x))

  7.  

  8. print("Costom transpose")

  9. print(tf.transpose(x,perm=[0,2,1]))

와 같고 이를 실행하면 다음과 같은 결과가 나온다.

  1. x shape is  tf.Tensor([2 2 3], shape=(3,), dtype=int32)

  2. Default transpose

  3. tf.Tensor(

  4. [[[ 1  7]

  5.   [ 4 10]]

  6.  

  7.  [[ 2  8]

  8.   [ 5 11]]

  9.  

  10.  [[ 3  9]

  11.   [ 6 12]]], shape=(3, 2, 2), dtype=int32)

  12. Costom transpose

  13. tf.Tensor(

  14. [[[ 1  4]

  15.   [ 2  5]

  16.   [ 3  6]]

  17.  

  18.  [[ 7 10]

  19.   [ 8 11]

  20.   [ 9 12]]], shape=(2, 3, 2), dtype=int32)

 

위 예에서 텐서의 사이즈를 보면 2x2x3 인데, 첫번째 2가 0축, 두번째 2가 1축 세번째 3이 2축이 각각 된다. perm 파라미터를 생략하면 perm = [2,1,0] 가 기본 값으로 들어가고, 이는 0번째 축과  2번째 축을 바꾸는 오퍼레이션을 하겠다는 의미기다.

 

이해가 잘 안가면 텐서의 각 값에 인덱스를 붙여 보면 이해가 쉽다.

 

X=[111121123113412151226123721182129213102211122212223]

 

여기서 X의  0번축과 2번축을 바꾸는 transpose를 구해보면

Xt=[111121123113412151226123721182129213102211122212223]t

이것은 각 원소들의 atijk=akji로 바꾸는 연산을 수행하는 것이다.