it-swarm.com.de

Wie dekodiert man One-Hot-Labels in Tensorflow?

Ich habe nachgeschaut, kann aber anscheinend keine Beispiele dafür finden, wie man einen One-Hot-Wert in TensorFlow dekodiert oder zurück in eine einzelne Ganzzahl konvertiert.

Ich habe tf.one_hot verwendet und konnte mein Modell trainieren, bin aber etwas verwirrt darüber, wie ich das Etikett nach meiner Klassifizierung verstehen soll. Meine Daten werden über eine von mir erstellte TFRecords -Datei eingespeist. Ich habe darüber nachgedacht, ein Textetikett in der Datei zu speichern, konnte es aber nicht zum Laufen bringen. Es sah so aus, als ob TFRecords keine Textzeichenfolge speichern könnte oder ich mich geirrt habe.

4
Matt Camp

Sie können den Index des größten Elements in der Matrix mit tf.argmax ermitteln. Da Ihr einziger heißer Vektor eindimensional ist und nur einen 1 und einen anderen 0 hat, funktioniert dies unter der Annahme, dass Sie sich mit einem einzelnen Vektor befassen.

index = tf.argmax(one_hot_vector, axis=0)

Für die Standardmatrix von batch_size * num_classes verwenden Sie axis=1, um ein Ergebnis der Größe batch_size * 1 zu erhalten.

12
martianwars

Da eine One-Hot-Codierung in der Regel nur eine Matrix mit batch_size Zeilen und num_classes Spalten ist und jede Zeile Null ist, wobei eine einzige Nicht-Null der ausgewählten Klasse entspricht, können Sie tf.argmax() verwenden, um einen Ganzzahlvektor wiederherzustellen Etiketten:

BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)

# ...
print sess.run(decoded)  # ==> array([1, 0, 3])
7
mrry
data = np.array([1, 5, 3, 8])
print(data)


def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    encoded = to_categorical(data)
    print('Shape of data (AFTER  encode): %s\n' % str(encoded.shape))
    return encoded


encoded_data = encode(data)
print(encoded_data)

def decode(datum):
    return np.argmax(datum)

decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(encoded_data[i])
    print('decoded datum: %s' % decoded_datum)
    decoded_Y.append(decoded_datum)


print("****************************************")

print(decoded_Y)
0
Rochan