numpy.argmax(a, axis=None, out=None)
Returns the indices of the maximum values along an axis.
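Parameters:
a : array_like
    Input array.
axis : int, optional
    By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
    If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns:
index_array : ndarray of ints
    Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.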
See also
amax
unravel_index
Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
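When every position of the maximum is needed rather than only the first, one possible workaround (not part of argmax itself) is to compare the array against its maximum value. A minimal sketch, reusing the array from the example in this post:

```python
import numpy as np

b = np.array([0, 5, 2, 3, 4, 5])

# argmax reports only the first position of the maximum.
first = int(np.argmax(b))

# Comparing against the maximum gives every position that holds it.
all_max = np.flatnonzero(b == b.max())

print(first)
print(all_max.tolist())
```

This works as written for 1-D arrays; for higher dimensions np.argwhere(b == b.max()) would return full coordinates instead.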
Examples
>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1
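With axis=None the result is an index into the flattened array, which is why unravel_index is listed under See also. A small sketch of turning that flat index back into a (row, column) position, using the same array as in the examples:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

# Without an axis, argmax indexes into the flattened array.
flat_index = np.argmax(a)

# unravel_index converts the flat index into coordinates for a.shape.
row, col = np.unravel_index(flat_index, a.shape)

print(int(flat_index), (int(row), int(col)))
print(a[row, col] == a.max())
```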
In multi-class model training I use it as follows, where org_labels = [0,1,2,....max_label] are the class labels, numbered from 0:
import numpy as np
# Presumably the scikit-learn helpers; the original snippet does not show its imports.
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# load_data, get_model and EarlyStoppingCallback are the author's own helpers,
# defined elsewhere and not shown here.

if __name__ == "__main__":
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=100)

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb,
                  run_id='cnn_handwrite')
    except StopIteration as e:
        print("OK, stop iterate!Good!")

    model.save(filename)

    # predict all data and calculate confusion_matrix
    model.load(filename)

    pro_arr = model.predict(X)
    # Each row of pro_arr holds per-class scores; argmax picks the predicted class index.
    predict_labels = np.argmax(pro_arr, axis=1)
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))
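The key step in the script is np.argmax(pro_arr, axis=1), which turns each row of predicted class probabilities into a single class index. The sketch below isolates that step with made-up numbers; the probability values, the four-class setup, and the one-hot targets are assumptions for illustration only, not output from the model above.

```python
import numpy as np

# Made-up probabilities for 3 samples and 4 classes (illustrative data only).
pro_arr = np.array([
    [0.10, 0.70, 0.10, 0.10],
    [0.60, 0.20, 0.10, 0.10],
    [0.05, 0.05, 0.10, 0.80],
])

# axis=1 takes the argmax across the class dimension, one label per sample.
predict_labels = np.argmax(pro_arr, axis=1)

# One-hot encoded targets can be decoded to integer labels the same way.
one_hot = np.eye(4)[[1, 0, 3]]
true_labels = np.argmax(one_hot, axis=1)

accuracy = float(np.mean(predict_labels == true_labels))
print(predict_labels.tolist(), true_labels.tolist(), accuracy)
```

Because the class labels start at 0, the index returned by argmax can be compared directly with the integer labels.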
Original article: https://www.cnblogs.com/bonelee/p/8976380.html