tf.gather_nd is really slow when used for many times
I would like to build a loss function in TensorFlow that is a complex combination of many individual tensor elements. For example, this code:
```python
import tensorflow as tf
import numpy as np
import time

input_layer = tf.placeholder(tf.float64, shape=[64, 4])
output_layer = input_layer + 0.5 * tf.tanh(tf.Variable(tf.random_uniform(
    shape=[64, 4], minval=-1, maxval=1, dtype=tf.float64)))

# random_combination is a 2-d numpy array of row indices of the form:
# [[32, 34, 23, 56], [23, 54, 33, 21], ...]
random_combination = np.random.randint(64, size=(210000000, 4))

# a collector to collect the per-combination terms
collector = []

print('start looping')
print(time.asctime(time.localtime(time.time())))

# loop through random_combination and pick the elements of output_layer
for row in random_combination:
    i, j, k, l = row
    # pick the needed elements from output_layer, one gather_nd op each
    # (assumed pairing: row i -> column 0, j -> 1, k -> 2, l -> 3)
    f1 = tf.gather_nd(output_layer, [i, 0])
    f2 = tf.gather_nd(output_layer, [j, 1])
    f3 = tf.gather_nd(output_layer, [k, 2])
    f4 = tf.gather_nd(output_layer, [l, 3])
    tf1 = f1 + 1
    tf2 = f2 + 1
    tf3 = f3 + 1
    tf4 = f4 + 1
    collector.append(0.3 * tf.abs(f1 * f2 * tf3 * tf4 - tf1 * tf2 * f3 * f4))

print('end looping')
print(time.asctime(time.localtime(time.time())))

# loss function
loss = tf.add_n(collector)
```
The loop alone (between the two printed timestamps) takes around 50 minutes on my computer.
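Every iteration of that loop adds new ops to the graph, so graph construction grows linearly with the number of rows. `tf.gather_nd` does not have to be called once per element, though: its `indices` argument can hold many `[row, column]` pairs, and the result has shape `indices.shape[:-1]`. Below is a minimal NumPy sketch of building all index pairs in one shot. It assumes the m-th entry of each row of `random_combination` is meant to be read at column m (i with 0, j with 1, k with 2, l with 3), and `build_gather_indices` is a hypothetical helper name.

```python
import numpy as np

def build_gather_indices(random_combination):
    """Turn an (N, 4) array of row indices into an (N, 4, 2) array of
    [row, column] pairs, pairing the m-th row index with column m."""
    n_rows, n_picks = random_combination.shape
    # columns[n, m] == m for every n (assumed pairing: i->0, j->1, k->2, l->3)
    columns = np.broadcast_to(np.arange(n_picks), (n_rows, n_picks))
    # stack along a new last axis: idx[n, m] == [random_combination[n, m], m]
    return np.stack([random_combination, columns], axis=-1).astype(np.int32)

random_combination = np.random.randint(64, size=(5, 4))
idx = build_gather_indices(random_combination)
print(idx.shape)  # (5, 4, 2)
```

Passed to a single `tf.gather_nd(output_layer, idx)` call, an array like this should yield one (N, 4) tensor instead of 4N scalar tensors, so the graph size no longer depends on N.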
Is this the proper way to write this in TensorFlow, or is there a more time-efficient way to index the elements?
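One direction worth trying is to compute every term at once rather than one scalar at a time. The sketch below uses plain NumPy so the idea can be checked without a TensorFlow session. It assumes the pairing row i with column 0, j with 1, k with 2, l with 3, and compares a per-row loop against a vectorized version; `loss_loop` and `loss_vectorized` are hypothetical helper names.

```python
import numpy as np

def loss_loop(output, random_combination):
    """Reference: one combination at a time, mirroring the per-row loop."""
    total = 0.0
    for i, j, k, l in random_combination:
        # assumed pairing: row i -> column 0, j -> 1, k -> 2, l -> 3
        f1, f2, f3, f4 = output[i, 0], output[j, 1], output[k, 2], output[l, 3]
        total += 0.3 * abs(f1 * f2 * (f3 + 1) * (f4 + 1)
                           - (f1 + 1) * (f2 + 1) * f3 * f4)
    return total

def loss_vectorized(output, random_combination):
    """All combinations at once: vals[n, m] = output[random_combination[n, m], m]."""
    cols = np.arange(random_combination.shape[1])
    # fancy indexing broadcasts cols (4,) against random_combination (N, 4)
    vals = output[random_combination, cols]
    f1, f2, f3, f4 = vals[:, 0], vals[:, 1], vals[:, 2], vals[:, 3]
    terms = 0.3 * np.abs(f1 * f2 * (f3 + 1) * (f4 + 1)
                         - (f1 + 1) * (f2 + 1) * f3 * f4)
    return terms.sum()

rng = np.random.default_rng(0)
output = rng.uniform(-1, 1, size=(64, 4))
random_combination = rng.integers(64, size=(1000, 4))
print(np.isclose(loss_loop(output, random_combination),
                 loss_vectorized(output, random_combination)))  # True
```

In the TensorFlow graph, `vals` would correspond to one batched gather from `output_layer` (for example `tf.gather_nd` with an (N, 4, 2) array of `[row, column]` pairs) and `terms.sum()` to `tf.reduce_sum`, replacing `tf.add_n` over one scalar tensor per row. With 210,000,000 combinations, an (N, 4, 2) int32 index array alone is about 6.7 GB, so feeding the combinations in chunks and summing the partial losses is probably necessary.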