From ac88c9bdd0e33014e4b9723273f34fc3d0646096 Mon Sep 17 00:00:00 2001 From: Calcitem Date: Mon, 3 Feb 2020 23:22:18 +0800 Subject: [PATCH] =?UTF-8?q?mcts-demo:=20refactor:=20state=20=E6=94=B9?= =?UTF-8?q?=E5=90=8D=E4=B8=BA=20game?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 并删除无用的已注释的代码 --- src/ai/mcts.cpp | 53 ++++++++++++++++--------------------------------- src/ai/mcts.h | 2 +- 2 files changed, 18 insertions(+), 37 deletions(-) diff --git a/src/ai/mcts.cpp b/src/ai/mcts.cpp index 485a2477..4ec2934c 100644 --- a/src/ai/mcts.cpp +++ b/src/ai/mcts.cpp @@ -69,8 +69,6 @@ void MCTSGame::generateMoves(Stack &moves) const return; } - //moves.reserve(numCols); - for (int col = 0; col < numCols; ++col) { if (board[0][col] == playerMarkers[0]) { moves.push_back(col); @@ -335,14 +333,9 @@ string Node::indentString(int indent) const ///////////////////////////////////////////////////////// ///////////////////////////////////////////////////////// -#if 0 -unique_ptr computeTree(const MCTSGame rootState, - const MCTSOptions options, - mt19937_64::result_type initialSeed) -#endif - Node *computeTree(const MCTSGame rootState, - const MCTSOptions options, - mt19937_64::result_type initialSeed) +Node *computeTree(const MCTSGame game, + const MCTSOptions options, + mt19937_64::result_type initialSeed) { mt19937_64 random_engine(initialSeed); @@ -355,10 +348,9 @@ unique_ptr computeTree(const MCTSGame rootState, } // Will support more players later. - assert(rootState.sideToMove == 1 || rootState.sideToMove == 2); + assert(game.sideToMove == 1 || game.sideToMove == 2); - // auto root = unique_ptr(new Node(rootState)); - Node *root = new Node(rootState); + Node *root = new Node(game); #ifdef USE_OPENMP double start_time = ::omp_get_wtime(); @@ -369,31 +361,31 @@ unique_ptr computeTree(const MCTSGame rootState, //auto node = root.get(); Node *node = root; - MCTSGame game = rootState; + MCTSGame tempGame = game; // Select a path through the tree to a leaf node. while (!node->hasUntriedMoves() && node->hasChildren()) { node = node->selectChildUCT(); - game.doMove(node->move); + tempGame.doMove(node->move); } // If we are not already at the final game, expand the // tree with a new node and move there. if (node->hasUntriedMoves()) { auto move = node->getUntriedMove(&random_engine); - game.doMove(move); - node = node->addChild(move, game); + tempGame.doMove(move); + node = node->addChild(move, tempGame); } // We now play randomly until the game ends. - while (game.hasMoves()) { - game.doRandomMove(&random_engine); + while (tempGame.hasMoves()) { + tempGame.doRandomMove(&random_engine); } // We have now reached a final game. Backpropagate the result // up the tree to the root node. while (node != nullptr) { - node->update(game.getResult(node->sideToMove)); + node->update(tempGame.getResult(node->sideToMove)); node = node->parent; } @@ -415,14 +407,14 @@ unique_ptr computeTree(const MCTSGame rootState, return root; } -move_t computeMove(const MCTSGame rootState, +move_t computeMove(const MCTSGame game, const MCTSOptions options) { // Will support more players later. - assert(rootState.sideToMove == 1 || rootState.sideToMove == 2); + assert(game.sideToMove == 1 || game.sideToMove == 2); Stack moves; - rootState.generateMoves(moves); + game.generateMoves(moves); assert(moves.size() > 0); if (moves.size() == 1) { return moves[0]; @@ -433,33 +425,23 @@ move_t computeMove(const MCTSGame rootState, #endif // Start all jobs to compute trees. - //vector>> rootFutures; future rootFutures[THREADS_COUNT]; MCTSOptions jobOptions = options; jobOptions.verbose = false; for (int t = 0; t < options.nThreads; ++t) { -#if 0 - auto func = [t, &rootState, &jobOptions]() -> unique_ptr { - return computeTree(rootState, jobOptions, 1012411 * t + 12515); - }; -#endif - - auto func = [t, &rootState, &jobOptions]() -> Node* { - return computeTree(rootState, jobOptions, 1012411 * t + 12515); + auto func = [t, &game, &jobOptions]() -> Node* { + return computeTree(game, jobOptions, 1012411 * t + 12515); }; - //rootFutures.push_back(async(launch::async, func)); rootFutures[t] = async(launch::async, func); } // Collect the results. - //vector> roots; Node *roots[THREADS_COUNT] = { nullptr }; for (int t = 0; t < options.nThreads; ++t) { - //roots.push_back(move(rootFutures[t].get())); roots[t] = move(rootFutures[t].get()); } @@ -469,7 +451,6 @@ move_t computeMove(const MCTSGame rootState, long long gamesPlayed = 0; for (int t = 0; t < options.nThreads; ++t) { - //auto root = roots[t].get(); Node *root = roots[t]; gamesPlayed += root->visits; diff --git a/src/ai/mcts.h b/src/ai/mcts.h index 999fe364..2fcd1570 100644 --- a/src/ai/mcts.h +++ b/src/ai/mcts.h @@ -198,7 +198,7 @@ private: double scoreUCT { 0 }; }; -move_t computeMove(const MCTSGame rootState, +move_t computeMove(const MCTSGame game, const MCTSOptions options = MCTSOptions());