diff --git a/src/ai/mcts.cpp b/src/ai/mcts.cpp index fcfa5924..5a1c797b 100644 --- a/src/ai/mcts.cpp +++ b/src/ai/mcts.cpp @@ -11,6 +11,7 @@ #include "mcts.h" #include "position.h" +#include "search.h" #ifdef MCTS_AI @@ -231,7 +232,7 @@ string Node::indentString(int indent) const ///////////////////////////////////////////////////////// ///////////////////////////////////////////////////////// -Node *computeTree(Game game, +Node *AIAlgorithm::computeTree(Game game, const MCTSOptions options, mt19937_64::result_type initialSeed) { @@ -305,7 +306,7 @@ Node *computeTree(Game game, return root; } -move_t computeMove(Game game, +move_t AIAlgorithm::computeMove(Game game, const MCTSOptions options) { // Will support more players later. @@ -329,7 +330,7 @@ move_t computeMove(Game game, jobOptions.verbose = true; for (int t = 0; t < options.nThreads; ++t) { - auto func = [t, &game, &jobOptions]() -> Node* { + auto func = [t, &game, &jobOptions, this]() -> Node* { return computeTree(game, jobOptions, 1012411 * t + 12515); }; @@ -420,71 +421,4 @@ ostream &operator << (ostream &out, Game &game) return out; } -/////////////////////////////////////////////////////////////////////////////////////////////////////// - -void runMCTS() -{ - bool humanPlayer = false; - - MCTSOptions optionsPlayer1, optionsPlayer2; - -#ifdef _DEBUG - optionsPlayer1.maxIterations = 10; - optionsPlayer1.verbose = true; - optionsPlayer2.maxIterations = 10; - optionsPlayer2.verbose = true; -#else - optionsPlayer1.maxIterations = 2000000; - optionsPlayer1.verbose = true; - optionsPlayer2.maxIterations = 2000000; - optionsPlayer2.verbose = true; -#endif - - Game game; - - while (game.hasMoves()) { - cout << endl << "State: " << game << endl; - - move_t move = MOVE_NONE; - if (game.position->sideToMove == PLAYER_BLACK) { - move = computeMove(game, optionsPlayer1); - game.doMove(move); - } else { - move = computeMove(game, optionsPlayer2); - game.doMove(move); - } - } - - cout << endl << "Final game: " << game << endl; - - if (game.getResult(PLAYER_WHITE) == 1.0) { - cout << "Player 1 wins!" << endl; - } else if (game.getResult(PLAYER_BLACK) == 1.0) { - cout << "Player 2 wins!" << endl; - } else { - cout << "Nobody wins!" << endl; - } -} - -int mcts_main() -{ - std::chrono::milliseconds::rep timeStart = std::chrono::duration_cast - (std::chrono::steady_clock::now().time_since_epoch()).count(); - - try { - runMCTS(); - } catch (runtime_error & error) { - cerr << "ERROR: " << error.what() << endl; - return 1; - } - - std::chrono::milliseconds::rep timeEnd = std::chrono::duration_cast - (std::chrono::steady_clock::now().time_since_epoch()).count(); - - std::chrono::milliseconds::rep totalTime = (timeEnd - timeStart); - - loggerDebug("\nTotal Time: %llums\n", totalTime); - - return 0; -} #endif // MCTS_AI diff --git a/src/ai/search.h b/src/ai/search.h index 0684c676..4c83f61d 100644 --- a/src/ai/search.h +++ b/src/ai/search.h @@ -44,6 +44,10 @@ #include "stopwatch.h" #endif +#ifdef MCTS_AI +#include "mcts.h" +#endif + class AIAlgorithm; class Game; class Node; @@ -199,6 +203,15 @@ public: static int nodeCompare(const Node *first, const Node *second); #endif // ALPHABETA_AI +#ifdef MCTS_AI + // TODO: 分离到 MCTS 算法类 + Node *computeTree(Game game, + const MCTSOptions options, + mt19937_64::result_type initialSeed); + move_t AIAlgorithm::computeMove(Game game, + const MCTSOptions options); +#endif + #ifdef ENDGAME_LEARNING bool findEndgameHash(hash_t hash, Endgame &endgame); static int recordEndgameHash(hash_t hash, const Endgame &endgame); diff --git a/src/base/aithread.cpp b/src/base/aithread.cpp index 775649f1..9af8074c 100644 --- a/src/base/aithread.cpp +++ b/src/base/aithread.cpp @@ -127,7 +127,7 @@ void AiThread::run() #ifdef MCTS_AI MCTSOptions mctsOptions; - move_t move = computeMove(*game, mctsOptions); + move_t move = ai.computeMove(*game, mctsOptions); strCommand = ai.moveToCommand(move); emitCommand();