diff --git a/APP_Framework/Applications/knowing_app/iris_ml_demo/LogisticRegressionModel.h b/APP_Framework/Applications/knowing_app/iris_ml_demo/LogisticRegressionModel.h new file mode 100644 index 00000000..6d7bbfa5 --- /dev/null +++ b/APP_Framework/Applications/knowing_app/iris_ml_demo/LogisticRegressionModel.h @@ -0,0 +1,41 @@ +#include + +/** + * Compute dot product + */ +float dot(float *x, ...) +{ + va_list w; + va_start(w, 4); + float dot = 0.0; + + for (int i = 0; i < 4; i++) { + const float wi = va_arg(w, double); + dot += x[i] * wi; + } + + return dot; +} + +/** + * Predict class for features vector + */ +int predict(float *x) +{ + float votes[3] = {0.0f}; + votes[0] = dot(x, -0.423405592418, 0.967388282125, -2.517050233286, -1.079182996654) + 9.84868307535428; + votes[1] = dot(x, 0.534517184386, -0.321908835083, -0.206465997471, -0.944448257908) + 2.238120068472271; + votes[2] = dot(x, -0.111111591968, -0.645479447042, 2.723516230758, 2.023631254562) + -12.086803143826813; + // return argmax of votes + int classIdx = 0; + float maxVotes = votes[0]; + + for (int i = 1; i < 3; i++) { + if (votes[i] > maxVotes) { + classIdx = i; + maxVotes = votes[i]; + } + } + + return classIdx; +} \ No newline at end of file diff --git a/APP_Framework/Applications/knowing_app/iris_ml_demo/iris_ml_demo.c b/APP_Framework/Applications/knowing_app/iris_ml_demo/iris_ml_demo.c index d901f18c..001af67e 100644 --- a/APP_Framework/Applications/knowing_app/iris_ml_demo/iris_ml_demo.c +++ b/APP_Framework/Applications/knowing_app/iris_ml_demo/iris_ml_demo.c @@ -76,3 +76,23 @@ void iris_DecisonTree_predict() #ifdef __RT_THREAD_H__ MSH_CMD_EXPORT(iris_DecisonTree_predict, iris predict by decison tree classifier); #endif + +void iris_LogisticRegression_predict() +{ +#include "LogisticRegressionModel.h" + int result; + + simple_CSV_read(); + + for (int i = 0; i < data_len; i++) { + result = predict(data[i]); + printf("data %d: ", i + 1); + for (int j = 0; j < FEATURE_NUM; j++) { + printf("%.4f ", data[i][j]); + } + printf("result: %d\n", result); + } +} +#ifdef __RT_THREAD_H__ +MSH_CMD_EXPORT(iris_LogisticRegression_predict, iris predict by logistic regression); +#endif