mcts: refactor: computeTree 和 computeMove 改放在 AIAlgorithm 类中

This commit is contained in:
Calcitem 2020-02-29 08:40:18 +08:00
parent 1e6456b902
commit a61df6e7d6
3 changed files with 18 additions and 71 deletions

View File

@ -11,6 +11,7 @@
#include "mcts.h"
#include "position.h"
#include "search.h"
#ifdef MCTS_AI
@ -231,7 +232,7 @@ string Node::indentString(int indent) const
/////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////
Node *computeTree(Game game,
Node *AIAlgorithm::computeTree(Game game,
const MCTSOptions options,
mt19937_64::result_type initialSeed)
{
@ -305,7 +306,7 @@ Node *computeTree(Game game,
return root;
}
move_t computeMove(Game game,
move_t AIAlgorithm::computeMove(Game game,
const MCTSOptions options)
{
// Will support more players later.
@ -329,7 +330,7 @@ move_t computeMove(Game game,
jobOptions.verbose = true;
for (int t = 0; t < options.nThreads; ++t) {
auto func = [t, &game, &jobOptions]() -> Node* {
auto func = [t, &game, &jobOptions, this]() -> Node* {
return computeTree(game, jobOptions, 1012411 * t + 12515);
};
@ -420,71 +421,4 @@ ostream &operator << (ostream &out, Game &game)
return out;
}
///////////////////////////////////////////////////////////////////////////////////////////////////////
void runMCTS()
{
bool humanPlayer = false;
MCTSOptions optionsPlayer1, optionsPlayer2;
#ifdef _DEBUG
optionsPlayer1.maxIterations = 10;
optionsPlayer1.verbose = true;
optionsPlayer2.maxIterations = 10;
optionsPlayer2.verbose = true;
#else
optionsPlayer1.maxIterations = 2000000;
optionsPlayer1.verbose = true;
optionsPlayer2.maxIterations = 2000000;
optionsPlayer2.verbose = true;
#endif
Game game;
while (game.hasMoves()) {
cout << endl << "State: " << game << endl;
move_t move = MOVE_NONE;
if (game.position->sideToMove == PLAYER_BLACK) {
move = computeMove(game, optionsPlayer1);
game.doMove(move);
} else {
move = computeMove(game, optionsPlayer2);
game.doMove(move);
}
}
cout << endl << "Final game: " << game << endl;
if (game.getResult(PLAYER_WHITE) == 1.0) {
cout << "Player 1 wins!" << endl;
} else if (game.getResult(PLAYER_BLACK) == 1.0) {
cout << "Player 2 wins!" << endl;
} else {
cout << "Nobody wins!" << endl;
}
}
int mcts_main()
{
std::chrono::milliseconds::rep timeStart = std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::steady_clock::now().time_since_epoch()).count();
try {
runMCTS();
} catch (runtime_error & error) {
cerr << "ERROR: " << error.what() << endl;
return 1;
}
std::chrono::milliseconds::rep timeEnd = std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::steady_clock::now().time_since_epoch()).count();
std::chrono::milliseconds::rep totalTime = (timeEnd - timeStart);
loggerDebug("\nTotal Time: %llums\n", totalTime);
return 0;
}
#endif // MCTS_AI

View File

@ -44,6 +44,10 @@
#include "stopwatch.h"
#endif
#ifdef MCTS_AI
#include "mcts.h"
#endif
class AIAlgorithm;
class Game;
class Node;
@ -199,6 +203,15 @@ public:
static int nodeCompare(const Node *first, const Node *second);
#endif // ALPHABETA_AI
#ifdef MCTS_AI
// TODO: 分离到 MCTS 算法类
Node *computeTree(Game game,
const MCTSOptions options,
mt19937_64::result_type initialSeed);
move_t AIAlgorithm::computeMove(Game game,
const MCTSOptions options);
#endif
#ifdef ENDGAME_LEARNING
bool findEndgameHash(hash_t hash, Endgame &endgame);
static int recordEndgameHash(hash_t hash, const Endgame &endgame);

View File

@ -127,7 +127,7 @@ void AiThread::run()
#ifdef MCTS_AI
MCTSOptions mctsOptions;
move_t move = computeMove(*game, mctsOptions);
move_t move = ai.computeMove(*game, mctsOptions);
strCommand = ai.moveToCommand(move);
emitCommand();