标签:keras git -o sum ada its date nump step
目录
Keras是一个框架
datasets
layers
losses
metrics
optimizers
Metrics
update_state
result().numpy()
reset_states
acc_meter = metrics.Accuarcy()
loss_meter = metrics.Mean
loss_meter.update_state(loss)
acc_meter.update_state(y,pred)
print(step, 'loss:', loss_meter.result().numpy())
# ...
print(step,'Evaluate Acc:', total_correct/total, acc_meter.result().numpy()
if step % 100 == 0:
print(step, 'loss:', loss_meter.result().numpy())
loss_meter.reset_states()
# ...
if step % 500 == 0:
total, total_correct = 0., 0
acc_meter.reset_states()
Compile
Fit
Evaluate
Predict
with tf.GradientTape() as tape:
x = tf.reshape(x, (-1, 28*28))
out = network(x)
y_onehot = tf.one_hot(y, depth=10)
loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))
grads = tape.gradient(loss, network.trainable_variables)
optimizer.apply_gradients(zip(grads, network.trainable_variables))
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])
for epoch in range(epochs):
for step, (x, y) in enumerate(db):
# ...
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])
network.fit(db, epochs=10)
if step % 500 == 0:
total, total_correct = 0., 0
for step, (x, y) in enumerate(ds_val):
x = tf.reshape(x, (-1, 28*28))
out = network(x)
pred = tf.argmax(out, axis=1)
pred = tf.cast(pred, dtype=tf.int32)
correct = tf.equal(pred, y)
total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
total += x.shape[0]
print(step, 'Evaluate Acc:', total_correct/total)
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])
# validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2)
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(fromlogits=True),
metircs=['accuracy'])
# validation_freq=2表示2个epochs做一次验证
network.fit(db, epochs=10, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)
sample = next(iter(ds_val))
x = sample[0]
y = sample[1]
pred = network.predict(x)
y = tf.argmax(y, axis=1)
pred = tf.argmax(pre, axis=1)
print(pred)
print(y)
标签:keras git -o sum ada its date nump step
原文地址:https://www.cnblogs.com/nickchen121/p/10921824.html