mcts: MCTS 在本游戏中可用

开启 MCTS_AI 宏即可。
This commit is contained in:
Calcitem 2020-02-26 02:02:39 +08:00
parent 9fe69de293
commit ac18becf93
8 changed files with 144 additions and 316 deletions

View File

@ -34,6 +34,8 @@
//#define UCT_DEMO
//#define MCTS_AI
#ifdef TEST_MODE
#define DONOT_PLAY_SOUND
#endif // TEST_MODE

View File

@ -1,6 +1,7 @@
//
// Petter Strandmark 2013
// petter.strandmark@gmail.com
// calcitem@outlook.com
//
// Monte Carlo Tree Search for finite games.
//
@ -8,195 +9,92 @@
// http://mcts.ai/code/python.html
//
#include <chrono>
#include "mcts.h"
#include "position.h"
void MCTSGame::doMove(move_t move)
{
assert(0 <= move && move < numCols);
assert(board[0][move] == playerMarkers[0]);
checkInvariant();
#ifdef MCTS_AI
int row = numRows - 1;
while (board[row][move] != playerMarkers[0]) row--;
board[row][move] = playerMarkers[sideToMove];
lastCol = move;
lastRow = row;
sideToMove = 3 - sideToMove;
}
template<typename RandomEngine>
void MCTSGame::doRandomMove(RandomEngine *engine)
void Game::doRandomMove(Node *node, mt19937_64 *engine)
{
assert(hasMoves());
checkInvariant();
generateMoves(moves);
uniform_int_distribution<move_t> moves(0, numCols - 1);
while (true) {
auto move = moves(*engine);
if (board[0][move] == playerMarkers[0]) {
doMove(move);
return;
}
}
int movesSize = moves.size();
uniform_int_distribution<int> index(0, movesSize - 1);
auto i = index(*engine);
move_t m = moves[i];
doMove(m);
}
bool MCTSGame::hasMoves() const
bool Game::hasMoves() const
{
checkInvariant();
char winner = getWinner();
if (winner != playerMarkers[0]) {
player_t winner = getWinner();
if (winner != PLAYER_NOBODY) {
return false;
}
for (int col = 0; col < numCols; ++col) {
if (board[0][col] == playerMarkers[0]) {
return true;
}
}
return false;
return true;
}
void MCTSGame::generateMoves(Stack<move_t, 8> &moves) const
{
checkInvariant();
if (getWinner() != playerMarkers[0]) {
return;
}
for (int col = 0; col < numCols; ++col) {
if (board[0][col] == playerMarkers[0]) {
moves.push_back(col);
}
}
}
char MCTSGame::getWinner() const
{
if (lastCol < 0) {
return playerMarkers[0];
}
// We only need to check around the last piece played.
auto piece = board[lastRow][lastCol];
// X X X X
int left = 0, right = 0;
for (int col = lastCol - 1; col >= 0 && board[lastRow][col] == piece; --col) left++;
for (int col = lastCol + 1; col < numCols && board[lastRow][col] == piece; ++col) right++;
if (left + 1 + right >= 4) {
return piece;
}
// X
// X
// X
// X
int up = 0, down = 0;
for (int row = lastRow - 1; row >= 0 && board[row][lastCol] == piece; --row) up++;
for (int row = lastRow + 1; row < numRows && board[row][lastCol] == piece; ++row) down++;
if (up + 1 + down >= 4) {
return piece;
}
// X
// X
// X
// X
up = 0;
down = 0;
for (int row = lastRow - 1, col = lastCol - 1; row >= 0 && col >= 0 && board[row][col] == piece; --row, --col) up++;
for (int row = lastRow + 1, col = lastCol + 1; row < numRows && col < numCols && board[row][col] == piece; ++row, ++col) down++;
if (up + 1 + down >= 4) {
return piece;
}
// X
// X
// X
// X
up = 0;
down = 0;
for (int row = lastRow + 1, col = lastCol - 1; row < numRows && col >= 0 && board[row][col] == piece; ++row, --col) up++;
for (int row = lastRow - 1, col = lastCol + 1; row >= 0 && col < numCols && board[row][col] == piece; --row, ++col) down++;
if (up + 1 + down >= 4) {
return piece;
}
return playerMarkers[0];
}
double MCTSGame::getResult(int currentSideToMove) const
double Game::getResult(player_t currentSideToMove) const
{
assert(!hasMoves());
checkInvariant();
auto winner = getWinner();
if (winner == playerMarkers[0]) {
if (winner == PLAYER_NOBODY) {
return 0.5;
}
if (winner == playerMarkers[currentSideToMove]) {
if (winner == currentSideToMove) {
return 0.0;
} else {
return 1.0;
}
}
void MCTSGame::print(ostream &out) const
void Game::checkInvariant() const
{
out << endl;
out << " ";
for (int col = 0; col < numCols - 1; ++col) {
out << col << ' ';
}
out << numCols - 1 << endl;
for (int row = 0; row < numRows; ++row) {
out << "|";
for (int col = 0; col < numCols - 1; ++col) {
out << board[row][col] << ' ';
}
out << board[row][numCols - 1] << "|" << endl;
}
out << "+";
for (int col = 0; col < numCols - 1; ++col) {
out << "--";
}
out << "-+" << endl;
out << playerMarkers[sideToMove] << " to move " << endl << endl;
}
void MCTSGame::checkInvariant() const
{
assert(sideToMove == 1 || sideToMove == 2);
assert(position->sideToMove == PLAYER_BLACK || position->sideToMove == PLAYER_WHITE);
}
////////////////////////////////////////////////////////////////////////////////////////
Node::Node(const MCTSGame &game) :
sideToMove(game.sideToMove)
Node::Node(Game &game) :
sideToMove(game.position->sideToMove)
{
game.generateMoves(moves);
}
Node::Node(const MCTSGame &game, const move_t &m, Node *p) :
Node::Node(Game &game, const move_t &m, Node *p) :
move(m),
parent(p),
sideToMove(game.sideToMove)
sideToMove(game.position->sideToMove)
{
game.generateMoves(moves);
}
void deleteChild(Node *node)
{
for (int i = 0; i < node->childrenSize; i++) {
deleteChild(node->children[i]);
}
node->childrenSize = 0;
delete node;
node = nullptr;
}
Node::~Node()
{
for (int i = 0; i < childrenSize; i++) {
delete children[i];
deleteChild(children[i]);
}
}
@ -244,7 +142,7 @@ Node *Node::selectChildUCT() const
for (int i = 0; i < childrenSize; i++) {
children[i]->scoreUCT = double(children[i]->wins) / double(children[i]->visits) +
sqrt(2.0 * log(double(this->visits)) / children[i]->visits);
sqrt(2.0 * log(double(this->visits)) / children[i]->visits);
}
double scoreMax = numeric_limits<double>::min();
@ -260,7 +158,7 @@ Node *Node::selectChildUCT() const
return nodeMax;
}
Node *Node::addChild(const move_t &move, const MCTSGame &game)
Node *Node::addChild(const move_t &move, Game &game)
{
auto node = new Node(game, move, this); // TODO: memmgr_alloc
@ -333,7 +231,7 @@ string Node::indentString(int indent) const
/////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////
Node *computeTree(const MCTSGame game,
Node *computeTree(Game game,
const MCTSOptions options,
mt19937_64::result_type initialSeed)
{
@ -348,7 +246,7 @@ Node *computeTree(const MCTSGame game,
}
// Will support more players later.
assert(game.sideToMove == 1 || game.sideToMove == 2);
assert(game.position->sideToMove == PLAYER_BLACK || game.position->sideToMove == PLAYER_WHITE);
Node *root = new Node(game);
@ -361,7 +259,7 @@ Node *computeTree(const MCTSGame game,
//auto node = root.get();
Node *node = root;
MCTSGame tempGame = game;
Game tempGame = game;
// Select a path through the tree to a leaf node.
while (!node->hasUntriedMoves() && node->hasChildren()) {
@ -379,7 +277,7 @@ Node *computeTree(const MCTSGame game,
// We now play randomly until the game ends.
while (tempGame.hasMoves()) {
tempGame.doRandomMove(&random_engine);
tempGame.doRandomMove(root, &random_engine);
}
// We have now reached a final game. Backpropagate the result
@ -407,13 +305,13 @@ Node *computeTree(const MCTSGame game,
return root;
}
move_t computeMove(const MCTSGame game,
move_t computeMove(Game game,
const MCTSOptions options)
{
// Will support more players later.
assert(game.sideToMove == 1 || game.sideToMove == 2);
assert(game.position->sideToMove == PLAYER_BLACK || game.position->sideToMove == PLAYER_WHITE);
Stack<move_t, 8> moves;
Stack<move_t, MOVE_COUNT> moves;
game.generateMoves(moves);
assert(moves.size() > 0);
if (moves.size() == 1) {
@ -428,7 +326,7 @@ move_t computeMove(const MCTSGame game,
future<Node *> rootFutures[THREADS_COUNT];
MCTSOptions jobOptions = options;
jobOptions.verbose = false;
jobOptions.verbose = true;
for (int t = 0; t < options.nThreads; ++t) {
auto func = [t, &game, &jobOptions]() -> Node* {
@ -466,7 +364,7 @@ move_t computeMove(const MCTSGame game,
wins[root->children[i]->move] += root->children[i]->wins;
}
delete root;
deleteChild(root);
root = nullptr;
}
@ -516,27 +414,24 @@ move_t computeMove(const MCTSGame game,
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
ostream &operator << (ostream &out, const MCTSGame &game)
ostream &operator << (ostream &out, Game &game)
{
game.print(out);
//game.print(out);
return out;
}
const char MCTSGame::playerMarkers[3] = { '.', 'X', 'O' };
///////////////////////////////////////////////////////////////////////////////////////////////////////
void runConnectFour()
void runMCTS()
{
bool humanPlayer = false;
MCTSOptions optionsPlayer1, optionsPlayer2;
#ifdef _DEBUG
optionsPlayer1.maxIterations = 100000;
optionsPlayer1.maxIterations = 10;
optionsPlayer1.verbose = true;
optionsPlayer2.maxIterations = 100000;
optionsPlayer2.maxIterations = 10;
optionsPlayer2.verbose = true;
#else
optionsPlayer1.maxIterations = 2000000;
@ -545,54 +440,39 @@ void runConnectFour()
optionsPlayer2.verbose = true;
#endif
MCTSGame game;
Game game;
while (game.hasMoves()) {
cout << endl << "State: " << game << endl;
move_t move = MCTSGame::noMove;
if (game.sideToMove == 1) {
move_t move = MOVE_NONE;
if (game.position->sideToMove == PLAYER_BLACK) {
move = computeMove(game, optionsPlayer1);
game.doMove(move);
} else {
if (humanPlayer) {
while (true) {
cout << "Input your move: ";
move = MCTSGame::noMove;
cin >> move;
try {
game.doMove(move);
break;
} catch (exception &) {
cout << "Invalid move." << endl;
}
}
} else {
move = computeMove(game, optionsPlayer2);
game.doMove(move);
}
move = computeMove(game, optionsPlayer2);
game.doMove(move);
}
}
cout << endl << "Final game: " << game << endl;
if (game.getResult(2) == 1.0) {
if (game.getResult(PLAYER_WHITE) == 1.0) {
cout << "Player 1 wins!" << endl;
} else if (game.getResult(1) == 1.0) {
} else if (game.getResult(PLAYER_BLACK) == 1.0) {
cout << "Player 2 wins!" << endl;
} else {
cout << "Nobody wins!" << endl;
}
}
#ifdef UCT_DEMO
int main()
int mcts_main()
{
std::chrono::milliseconds::rep timeStart = std::chrono::duration_cast<std::chrono::milliseconds>
(std::chrono::steady_clock::now().time_since_epoch()).count();
try {
runConnectFour();
runMCTS();
} catch (runtime_error & error) {
cerr << "ERROR: " << error.what() << endl;
return 1;
@ -607,4 +487,4 @@ int main()
return 0;
}
#endif // UCT_DEMO
#endif // MCTS_AI

View File

@ -62,7 +62,7 @@ private:
#include "stack.h"
#ifdef _DEBUG
#ifdef _WIN32
#define USE_OPENMP
#endif
@ -74,132 +74,15 @@ private:
using namespace std;
typedef int move_t;
class MCTSGame
{
public:
typedef int move_t;
static const move_t noMove = -1;
static const char playerMarkers[3];
MCTSGame()
: sideToMove(1),
lastCol(-1),
lastRow(-1)
{
for (int r = 0; r < numRows; r++) {
for (int c = 0; c < numCols; c++) {
board[r][c] = playerMarkers[0];
}
}
}
void doMove(move_t move);
template<typename RandomEngine>
void doRandomMove(RandomEngine *engine);
bool hasMoves() const;
void generateMoves(Stack<move_t, 8> &moves) const;
char getWinner() const;
double getResult(int currentSideToMove) const;
void print(ostream &out) const;
int sideToMove;
private:
void checkInvariant() const;
static const int numRows = 6;
static const int numCols = 7;
char board[numRows][numCols];
int lastCol;
int lastRow;
};
static const int THREADS_COUNT = 2;
class MCTSOptions
{
public:
int nThreads { THREADS_COUNT };
int maxIterations { 10000 };
double maxTime { -1.0 };
bool verbose { false };
int nThreads { THREADS_COUNT };
int maxIterations { 40000000 };
double maxTime { 6000 };
bool verbose { true };
};
//
//
// [1] Chaslot, G. M. B., Winands, M. H., & van Den Herik, H. J. (2008).
// Parallel monte-carlo tree search. In Computers and Games (pp.
// 60-71). Springer Berlin Heidelberg.
//
//
// This class is used to build the game tree. The root is created by the users and
// the rest of the tree is created by add_node.
//
class Node
{
public:
Node(const MCTSGame &game);
~Node();
bool hasUntriedMoves() const;
template<typename RandomEngine>
move_t getUntriedMove(RandomEngine *engine) const;
Node *bestChildren() const;
bool hasChildren() const;
Node *selectChildUCT() const;
Node *addChild(const move_t &move, const MCTSGame &game);
void update(double result);
string toString();
string treeToString(int max_depth = 1000000, int indent = 0) const;
static const int NODE_CHILDREN_SIZE = 8;
const move_t move { MCTSGame::noMove };
Node *const parent {nullptr};
const int sideToMove;
//atomic<double> wins;
//atomic<int> visits;
double wins { 0 };
int visits { 0 };
Stack<move_t, 8> moves;
Node *children[NODE_CHILDREN_SIZE];
int childrenSize { 0 };
private:
Node(const MCTSGame &game, const move_t &move, Node *parent);
string indentString(int indent) const;
Node(const Node &);
Node &operator = (const Node &);
double scoreUCT { 0 };
};
move_t computeMove(const MCTSGame game,
const MCTSOptions options = MCTSOptions());
#endif // MCTS_HEADER_PETTER

View File

@ -75,6 +75,9 @@ void Game::generateMoves(Stack<move_t, MOVE_COUNT> &moves)
continue;
}
#ifdef MCTS_AI
moves.push_back((move_t)square);
#else // MCTS_AI
if (position->phase != PHASE_READY) {
moves.push_back((move_t)square);
} else {
@ -83,6 +86,7 @@ void Game::generateMoves(Stack<move_t, MOVE_COUNT> &moves)
moves.push_back((move_t)square);
}
}
#endif // MCTS_AI
}
break;
}

View File

@ -63,10 +63,12 @@ public:
Node();
~Node();
bool hasChildren() const
{
return (childrenSize != 0);
}
#ifdef MCTS_AI
Node(Game &game);
Node(Game &game, const move_t &move, Node *parent);
#endif // MCTS_AI
bool hasChildren() const;
Node *addChild(
const move_t &move,
@ -79,6 +81,29 @@ public:
static const int NODE_CHILDREN_SIZE = (4 * 4 + 3 * 4 * 2); // TODO: 缩减空间
#ifdef MCTS_AI
Stack<move_t, NODE_CHILDREN_SIZE> moves;
//atomic<double> wins;
//atomic<int> visits;
double wins { 0 };
int visits { 0 };
double scoreUCT { 0 };
bool hasUntriedMoves() const;
template<typename RandomEngine>
move_t getUntriedMove(RandomEngine *engine) const;
Node *bestChildren() const;
Node *selectChildUCT() const;
Node *addChild(const move_t &move, Game &game);
void update(double result);
string toString();
string treeToString(int max_depth = 1000000, int indent = 0) const;
string indentString(int indent) const;
#else
move_t moves[NODE_CHILDREN_SIZE];
#endif // MCTS_AI
move_t move { MOVE_NONE }; // 着法的命令行指令,图上标示为节点前的连线
value_t value { VALUE_UNKNOWN }; // 节点的值
rating_t rating { RATING_ZERO }; // 节点分数
@ -87,7 +112,6 @@ public:
bool pruned { false }; // 是否在此处剪枝
#endif
move_t moves[NODE_CHILDREN_SIZE];
Node *children[NODE_CHILDREN_SIZE];
int childrenSize { 0 };

View File

@ -92,6 +92,11 @@ void AiThread::emitCommand()
emit command(strCommand);
}
#ifdef MCTS_AI
move_t computeMove(Game game,
const MCTSOptions options);
#endif // MCTS_AI
void AiThread::run()
{
// 测试用数据
@ -119,6 +124,15 @@ void AiThread::run()
emit searchStarted();
mutex.unlock();
#ifdef MCTS_AI
MCTSOptions mctsOptions;
move_t move = computeMove(*game, mctsOptions);
strCommand = ai.moveToCommand(move);
emitCommand();
#else // MCTS_AI
if (ai.search(depth) == 3) {
// 三次重复局面和
loggerDebug("Draw\n\n");
@ -132,6 +146,8 @@ void AiThread::run()
}
}
#endif // MCTS_AI
emit searchFinished();
// 执行完毕后继续判断

View File

@ -836,6 +836,11 @@ bool Game::doMove(move_t m)
return false;
}
player_t Game::getWinner() const
{
return winner;
}
int Game::update()
{
int ret = -1;

View File

@ -29,6 +29,10 @@
#include "board.h"
#include "search.h"
#ifdef MCTS_AI
#include "mcts.h"
#endif
using namespace std;
class AIAlgorithm;
@ -138,12 +142,6 @@ public:
return position->action;
}
// 判断胜负
player_t getWinner() const
{
return winner;
}
// 玩家1或玩家2的用时
time_t getElapsedTime(int playerId);
@ -249,6 +247,9 @@ public:
// 着法生成
void generateMoves(Stack<move_t, MOVE_COUNT> &moves);
// 判断胜负
player_t getWinner() const;
// 下面几个函数没有算法无关判断和无关操作,节约算法时间
bool doMove(move_t move);
bool choose(square_t square);
@ -261,6 +262,19 @@ public:
hash_t updateHash(square_t square);
hash_t updateHashMisc();
#ifdef MCTS_AI
// MCTS 相关
Stack<move_t, MOVE_COUNT> moves;
//template<typename RandomEngine>
//void doRandomMove(RandomEngine *engine);
void doRandomMove(Node *node, mt19937_64 *engine);
bool hasMoves() const;
double getResult(player_t currentSideToMove) const;
void checkInvariant() const;
#endif // MCTS_AI
// 赢盘数
int score[COLOR_COUNT] = { 0 };
int score_draw { 0 };