mcts-demo: refactor: state 改名为 game

并删除无用的已注释的代码
This commit is contained in:
Calcitem 2020-02-03 23:22:18 +08:00
parent 8a4e44c94e
commit ac88c9bdd0
2 changed files with 18 additions and 37 deletions

View File

@ -69,8 +69,6 @@ void MCTSGame::generateMoves(Stack<move_t, 8> &moves) const
return; return;
} }
//moves.reserve(numCols);
for (int col = 0; col < numCols; ++col) { for (int col = 0; col < numCols; ++col) {
if (board[0][col] == playerMarkers[0]) { if (board[0][col] == playerMarkers[0]) {
moves.push_back(col); moves.push_back(col);
@ -335,14 +333,9 @@ string Node::indentString(int indent) const
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
#if 0 Node *computeTree(const MCTSGame game,
unique_ptr<Node> computeTree(const MCTSGame rootState, const MCTSOptions options,
const MCTSOptions options, mt19937_64::result_type initialSeed)
mt19937_64::result_type initialSeed)
#endif
Node *computeTree(const MCTSGame rootState,
const MCTSOptions options,
mt19937_64::result_type initialSeed)
{ {
mt19937_64 random_engine(initialSeed); mt19937_64 random_engine(initialSeed);
@ -355,10 +348,9 @@ unique_ptr<Node> computeTree(const MCTSGame rootState,
} }
// Will support more players later. // Will support more players later.
assert(rootState.sideToMove == 1 || rootState.sideToMove == 2); assert(game.sideToMove == 1 || game.sideToMove == 2);
// auto root = unique_ptr<Node>(new Node(rootState)); Node *root = new Node(game);
Node *root = new Node(rootState);
#ifdef USE_OPENMP #ifdef USE_OPENMP
double start_time = ::omp_get_wtime(); double start_time = ::omp_get_wtime();
@ -369,31 +361,31 @@ unique_ptr<Node> computeTree(const MCTSGame rootState,
//auto node = root.get(); //auto node = root.get();
Node *node = root; Node *node = root;
MCTSGame game = rootState; MCTSGame tempGame = game;
// Select a path through the tree to a leaf node. // Select a path through the tree to a leaf node.
while (!node->hasUntriedMoves() && node->hasChildren()) { while (!node->hasUntriedMoves() && node->hasChildren()) {
node = node->selectChildUCT(); node = node->selectChildUCT();
game.doMove(node->move); tempGame.doMove(node->move);
} }
// If we are not already at the final game, expand the // If we are not already at the final game, expand the
// tree with a new node and move there. // tree with a new node and move there.
if (node->hasUntriedMoves()) { if (node->hasUntriedMoves()) {
auto move = node->getUntriedMove(&random_engine); auto move = node->getUntriedMove(&random_engine);
game.doMove(move); tempGame.doMove(move);
node = node->addChild(move, game); node = node->addChild(move, tempGame);
} }
// We now play randomly until the game ends. // We now play randomly until the game ends.
while (game.hasMoves()) { while (tempGame.hasMoves()) {
game.doRandomMove(&random_engine); tempGame.doRandomMove(&random_engine);
} }
// We have now reached a final game. Backpropagate the result // We have now reached a final game. Backpropagate the result
// up the tree to the root node. // up the tree to the root node.
while (node != nullptr) { while (node != nullptr) {
node->update(game.getResult(node->sideToMove)); node->update(tempGame.getResult(node->sideToMove));
node = node->parent; node = node->parent;
} }
@ -415,14 +407,14 @@ unique_ptr<Node> computeTree(const MCTSGame rootState,
return root; return root;
} }
move_t computeMove(const MCTSGame rootState, move_t computeMove(const MCTSGame game,
const MCTSOptions options) const MCTSOptions options)
{ {
// Will support more players later. // Will support more players later.
assert(rootState.sideToMove == 1 || rootState.sideToMove == 2); assert(game.sideToMove == 1 || game.sideToMove == 2);
Stack<move_t, 8> moves; Stack<move_t, 8> moves;
rootState.generateMoves(moves); game.generateMoves(moves);
assert(moves.size() > 0); assert(moves.size() > 0);
if (moves.size() == 1) { if (moves.size() == 1) {
return moves[0]; return moves[0];
@ -433,33 +425,23 @@ move_t computeMove(const MCTSGame rootState,
#endif #endif
// Start all jobs to compute trees. // Start all jobs to compute trees.
//vector<future<unique_ptr<Node>>> rootFutures;
future<Node *> rootFutures[THREADS_COUNT]; future<Node *> rootFutures[THREADS_COUNT];
MCTSOptions jobOptions = options; MCTSOptions jobOptions = options;
jobOptions.verbose = false; jobOptions.verbose = false;
for (int t = 0; t < options.nThreads; ++t) { for (int t = 0; t < options.nThreads; ++t) {
#if 0 auto func = [t, &game, &jobOptions]() -> Node* {
auto func = [t, &rootState, &jobOptions]() -> unique_ptr<Node> { return computeTree(game, jobOptions, 1012411 * t + 12515);
return computeTree(rootState, jobOptions, 1012411 * t + 12515);
};
#endif
auto func = [t, &rootState, &jobOptions]() -> Node* {
return computeTree(rootState, jobOptions, 1012411 * t + 12515);
}; };
//rootFutures.push_back(async(launch::async, func));
rootFutures[t] = async(launch::async, func); rootFutures[t] = async(launch::async, func);
} }
// Collect the results. // Collect the results.
//vector<unique_ptr<Node>> roots;
Node *roots[THREADS_COUNT] = { nullptr }; Node *roots[THREADS_COUNT] = { nullptr };
for (int t = 0; t < options.nThreads; ++t) { for (int t = 0; t < options.nThreads; ++t) {
//roots.push_back(move(rootFutures[t].get()));
roots[t] = move(rootFutures[t].get()); roots[t] = move(rootFutures[t].get());
} }
@ -469,7 +451,6 @@ move_t computeMove(const MCTSGame rootState,
long long gamesPlayed = 0; long long gamesPlayed = 0;
for (int t = 0; t < options.nThreads; ++t) { for (int t = 0; t < options.nThreads; ++t) {
//auto root = roots[t].get();
Node *root = roots[t]; Node *root = roots[t];
gamesPlayed += root->visits; gamesPlayed += root->visits;

View File

@ -198,7 +198,7 @@ private:
double scoreUCT { 0 }; double scoreUCT { 0 };
}; };
move_t computeMove(const MCTSGame rootState, move_t computeMove(const MCTSGame game,
const MCTSOptions options = MCTSOptions()); const MCTSOptions options = MCTSOptions());