mcts: refactor: computeTree 和 computeMove 改放在 AIAlgorithm 类中
This commit is contained in:
parent
1e6456b902
commit
a61df6e7d6
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mcts.h"
|
||||
#include "position.h"
|
||||
#include "search.h"
|
||||
|
||||
#ifdef MCTS_AI
|
||||
|
||||
|
@ -231,7 +232,7 @@ string Node::indentString(int indent) const
|
|||
/////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////
|
||||
|
||||
Node *computeTree(Game game,
|
||||
Node *AIAlgorithm::computeTree(Game game,
|
||||
const MCTSOptions options,
|
||||
mt19937_64::result_type initialSeed)
|
||||
{
|
||||
|
@ -305,7 +306,7 @@ Node *computeTree(Game game,
|
|||
return root;
|
||||
}
|
||||
|
||||
move_t computeMove(Game game,
|
||||
move_t AIAlgorithm::computeMove(Game game,
|
||||
const MCTSOptions options)
|
||||
{
|
||||
// Will support more players later.
|
||||
|
@ -329,7 +330,7 @@ move_t computeMove(Game game,
|
|||
jobOptions.verbose = true;
|
||||
|
||||
for (int t = 0; t < options.nThreads; ++t) {
|
||||
auto func = [t, &game, &jobOptions]() -> Node* {
|
||||
auto func = [t, &game, &jobOptions, this]() -> Node* {
|
||||
return computeTree(game, jobOptions, 1012411 * t + 12515);
|
||||
};
|
||||
|
||||
|
@ -420,71 +421,4 @@ ostream &operator << (ostream &out, Game &game)
|
|||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void runMCTS()
|
||||
{
|
||||
bool humanPlayer = false;
|
||||
|
||||
MCTSOptions optionsPlayer1, optionsPlayer2;
|
||||
|
||||
#ifdef _DEBUG
|
||||
optionsPlayer1.maxIterations = 10;
|
||||
optionsPlayer1.verbose = true;
|
||||
optionsPlayer2.maxIterations = 10;
|
||||
optionsPlayer2.verbose = true;
|
||||
#else
|
||||
optionsPlayer1.maxIterations = 2000000;
|
||||
optionsPlayer1.verbose = true;
|
||||
optionsPlayer2.maxIterations = 2000000;
|
||||
optionsPlayer2.verbose = true;
|
||||
#endif
|
||||
|
||||
Game game;
|
||||
|
||||
while (game.hasMoves()) {
|
||||
cout << endl << "State: " << game << endl;
|
||||
|
||||
move_t move = MOVE_NONE;
|
||||
if (game.position->sideToMove == PLAYER_BLACK) {
|
||||
move = computeMove(game, optionsPlayer1);
|
||||
game.doMove(move);
|
||||
} else {
|
||||
move = computeMove(game, optionsPlayer2);
|
||||
game.doMove(move);
|
||||
}
|
||||
}
|
||||
|
||||
cout << endl << "Final game: " << game << endl;
|
||||
|
||||
if (game.getResult(PLAYER_WHITE) == 1.0) {
|
||||
cout << "Player 1 wins!" << endl;
|
||||
} else if (game.getResult(PLAYER_BLACK) == 1.0) {
|
||||
cout << "Player 2 wins!" << endl;
|
||||
} else {
|
||||
cout << "Nobody wins!" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
int mcts_main()
|
||||
{
|
||||
std::chrono::milliseconds::rep timeStart = std::chrono::duration_cast<std::chrono::milliseconds>
|
||||
(std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
|
||||
try {
|
||||
runMCTS();
|
||||
} catch (runtime_error & error) {
|
||||
cerr << "ERROR: " << error.what() << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::chrono::milliseconds::rep timeEnd = std::chrono::duration_cast<std::chrono::milliseconds>
|
||||
(std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
|
||||
std::chrono::milliseconds::rep totalTime = (timeEnd - timeStart);
|
||||
|
||||
loggerDebug("\nTotal Time: %llums\n", totalTime);
|
||||
|
||||
return 0;
|
||||
}
|
||||
#endif // MCTS_AI
|
||||
|
|
|
@ -44,6 +44,10 @@
|
|||
#include "stopwatch.h"
|
||||
#endif
|
||||
|
||||
#ifdef MCTS_AI
|
||||
#include "mcts.h"
|
||||
#endif
|
||||
|
||||
class AIAlgorithm;
|
||||
class Game;
|
||||
class Node;
|
||||
|
@ -199,6 +203,15 @@ public:
|
|||
static int nodeCompare(const Node *first, const Node *second);
|
||||
#endif // ALPHABETA_AI
|
||||
|
||||
#ifdef MCTS_AI
|
||||
// TODO: 分离到 MCTS 算法类
|
||||
Node *computeTree(Game game,
|
||||
const MCTSOptions options,
|
||||
mt19937_64::result_type initialSeed);
|
||||
move_t AIAlgorithm::computeMove(Game game,
|
||||
const MCTSOptions options);
|
||||
#endif
|
||||
|
||||
#ifdef ENDGAME_LEARNING
|
||||
bool findEndgameHash(hash_t hash, Endgame &endgame);
|
||||
static int recordEndgameHash(hash_t hash, const Endgame &endgame);
|
||||
|
|
|
@ -127,7 +127,7 @@ void AiThread::run()
|
|||
#ifdef MCTS_AI
|
||||
MCTSOptions mctsOptions;
|
||||
|
||||
move_t move = computeMove(*game, mctsOptions);
|
||||
move_t move = ai.computeMove(*game, mctsOptions);
|
||||
|
||||
strCommand = ai.moveToCommand(move);
|
||||
emitCommand();
|
||||
|
|
Loading…
Reference in New Issue