OpenCV3.3中 K-最近邻法(KNN)接口简介及使用
OpenCV 3.3中給出了K-最近鄰(KNN)算法的實現,即cv::ml::Knearest類,此類的聲明在include/opecv2/ml.hpp文件中,實現在modules/ml/src/knearest.cpp文件中。其中:
(1)、cv::ml::Knearest類:繼承自cv::ml::StateModel,而cv::ml::StateModel又繼承自cv::Algorithm;
(2)、create函數:為static,new一個KNearestImpl用來創建一個KNearest對象;
(3)、setDefaultK/getDefaultK函數:在預測時,設置/獲取的K值;
(4)、setIsClassifier/getIsClassifier函數:設置/獲取應用KNN是進行分類還是回歸;
(5)、setEmax/getEmax函數:在使用KDTree算法時,設置/獲取Emax參數值;
(6)、setAlgorithmType/getAlgorithmType函數:設置/獲取KNN算法類型,目前支持兩種:brute_force和KDTree;
(7)、findNearest函數:根據輸入預測分類/回歸結果。
關于KNN算法介紹可以參考:?http://blog.csdn.net/fengbingchun/article/details/78464169 ?
以下是從數據集MNIST中提取的40幅圖像,0,1,2,3四類各20張,每類的前10幅來自于訓練樣本,用于訓練,后10幅來自測試樣本,用于測試,如下圖:
關于MNIST的介紹可以參考:? http://blog.csdn.net/fengbingchun/article/details/49611549?
測試代碼如下:
#include "opencv.hpp"
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include "common.hpp"/// K-Nearest Neighbor(KNN) //
int test_opencv_knn_predict()
{const int K{ 3 };cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();knn->setDefaultK(K);knn->setIsClassifier(true);knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE);const std::string image_path{"E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/"};cv::Mat tmp = cv::imread(image_path + "0_1.jpg", 0);const int train_samples_number{ 40 }, predict_samples_number{ 40 };const int every_class_number{ 10 };cv::Mat train_data(train_samples_number, tmp.rows * tmp.cols, CV_32FC1);cv::Mat train_labels(train_samples_number, 1, CV_32FC1);float* p = (float*)train_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}// train datafor (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 1; j <= every_class_number; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = train_data.rowRange(i * every_class_number + j - 1, i * every_class_number + j);image.copyTo(tmp);}}knn->train(train_data, cv::ml::ROW_SAMPLE, train_labels);// predict dattacv::Mat predict_data(predict_samples_number, tmp.rows * tmp.cols, CV_32FC1);for (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 11; j <= every_class_number+10; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = predict_data.rowRange(i * every_class_number + j - 10 - 1, i * every_class_number + j - 10);image.copyTo(tmp);}}cv::Mat result;knn->findNearest(predict_data, K, result);CHECK(result.rows == predict_samples_number);cv::Mat predict_labels(predict_samples_number, 1, CV_32FC1);p = (float*)predict_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}int count{ 0 };for (int i = 0; i < predict_samples_number; ++i) {float value1 = ((float*)predict_labels.data)[i];float value2 = ((float*)result.data)[i];fprintf(stdout, "expected value: %f, actual value: %f\n", value1, value2);if (int(value1) == int(value2)) ++count;}fprintf(stdout, "when K = %d, accuracy: %f\n", K, count * 1.f / predict_samples_number);return 0;
}
測試結果如下:
GitHub:?https://github.com/fengbingchun/NN_Test ??
總結
以上是生活随笔為你收集整理的OpenCV3.3中 K-最近邻法(KNN)接口简介及使用的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: K-最近邻法(KNN)简介
- 下一篇: Brute Force算法介绍及C++实