Solving the Water Sort Puzzle with Monte Carlo Tree Search
https://ai-boson.github.io/mcts/
https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
I had no prior background in game trees or Monte Carlo search. At the time the puzzle looked like it did not branch much, so I went straight for a brute-force BFS; it blew up the memory (60+ GB) at around depth 30. Even after adding some score-based pruning and move ordering, I still could not get past depth 35.
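For context, the search I started with looked roughly like the sketch below. This is a minimal reconstruction rather than my original code; it assumes the State class defined further down in this post, and bfs_solve is just a name I made up here. Every frontier node carries its whole path plus a deep-copied state, which is exactly why memory explodes at this depth.

# Minimal sketch of the brute-force BFS (hypothetical reconstruction;
# relies on the State class defined below).
from collections import deque

def bfs_solve(initial_state, max_depth=35):
    queue = deque([(initial_state, [])])
    seen = {initial_state}                    # needs State.__eq__/__hash__
    while queue:
        state, path = queue.popleft()
        if state.game_result() == 1:
            return path                       # list of (from, to) pours
        if len(path) >= max_depth:
            continue
        for action in state.get_legal_actions():
            nxt = state.move(action)
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, path + [action]))
    return None                               # nothing found within max_depth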
Now, using MCTS, the result is truly amazing!!!

One thing to watch out for: unlike games such as Gomoku, the Water Sort Puzzle can return to previously visited states, so an unconstrained rollout may cycle for a very long time. To prevent this, I keep a set of visited states inside the rollout and skip any successor already in it.
import numpy as np
from collections import defaultdict
from copy import deepcopy
class State():
    # A puzzle state: a list of containers, each a list of up to 4 color ids,
    # ordered bottom-to-top. Every color occupies exactly 4 units in total.
    def __init__(self, state):
        self.state = state
        assert isinstance(self.state, list)
        for item in state:
            assert isinstance(item, list)
        self.num_of_container = len(self.state)
        cnt = 0
        for item in state:
            cnt += len(item)
        assert cnt % 4 == 0
        self.num_of_color = cnt // 4
    def __eq__(self, other):
        if not isinstance(other, State):
            return False
        return self.state == other.state

    def __hash__(self):
        # Hash the full nested structure so that equal states (per __eq__)
        # always hash equally, while unequal ones rarely collide.
        return hash(tuple(tuple(container) for container in self.state))
    @staticmethod
    def same_color_in_container(container):
        # True if the container is empty or holds only one color.
        if len(container) == 0:
            return True
        color = container[0]
        for i in range(1, len(container)):
            if container[i] != color:
                return False
        return True

    def __check(self, from_idx, to_idx):
        # A pour from from_idx to to_idx is legal iff:
        if from_idx < 0 or from_idx >= self.num_of_container or \
           to_idx < 0 or to_idx >= self.num_of_container or from_idx == to_idx:
            return False
        fv = self.state[from_idx]
        tv = self.state[to_idx]
        if len(fv) == 0:                # the source is not empty,
            return False
        if len(tv) == 4:                # the target is not full,
            return False
        if len(tv) > 0 and tv[-1] != fv[-1]:
            return False                # the top colors match,
        if len(fv) == 4 and fv[0] == fv[1] == fv[2] == fv[3]:
            return False                # the source is not already completed,
        if len(tv) == 0 and State.same_color_in_container(fv):
            return False                # and we do not pointlessly shuffle a
        return True                     # single-color container into an empty one.

    def get_legal_actions(self):
        # Enumerate every legal (from, to) pour.
        actions = []
        for i in range(self.num_of_container):
            for j in range(self.num_of_container):
                if self.__check(i, j):
                    actions.append((i, j))
        return actions
    def move(self, action):
        # Apply a pour (assumed legal) and return the resulting state,
        # leaving self untouched.
        ret = deepcopy(self)
        from_idx, to_idx = action
        fv = ret.state[from_idx]
        tv = ret.state[to_idx]
        # Count the run of same-colored units on top of the source...
        c = fv[-1]
        cnt = 1
        i = len(fv) - 2
        while i >= 0 and fv[i] == c:
            i -= 1
            cnt += 1
        # ...then pour as many of them as the target has room for.
        cnt = min(cnt, 4 - len(tv))
        for i in range(cnt):
            fv.pop()
            tv.append(c)
        return ret
    def game_result(self):
        # 1 if solved (every container empty or completed), -1 otherwise.
        for v in self.state:
            if len(v) == 0:
                continue
            if len(v) < 4:
                return -1
            for i in range(3):
                if v[i] != v[i + 1]:
                    return -1
        return 1

    def is_game_over(self):
        # Both a solved state and a dead end have no legal pours left.
        return len(self.get_legal_actions()) == 0

    def __str__(self):
        s = ''
        for item in self.state:
            s += str(item) + '\n'
        return s
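# Quick sanity check of State on a tiny hypothetical 2-color puzzle
# (my own example, not from the original post):
#   s = State([[1, 1, 2, 2], [2, 2, 1, 1], []])
#   s.get_legal_actions()  # -> [(0, 2), (1, 2)]
#   print(s.move((0, 2)))  # the two top 2s move: [1, 1] / [2, 2, 1, 1] / [2, 2]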
class MonteCarloTreeSearchNode():
    def __init__(self, state, parent=None, parent_action=None):
        self.state = state
        self.parent = parent
        self.parent_action = parent_action   # the pour that led to this node
        self.children = []
        self._number_of_visits = 0
        self._results = defaultdict(int)     # maps reward (1 or -1) to count
        self._untried_actions = self.untried_actions()

    def __str__(self):
        s = str(self.state)
        s += 'num_of_vis={}\n'.format(self._number_of_visits)
        s += 'wins={} loss={}\n'.format(self._results[1], self._results[-1])
        return s

    def untried_actions(self):
        # Actions not yet expanded into child nodes.
        self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    def q(self):
        # Net score: wins minus losses over all simulations through this node.
        wins = self._results[1]
        losses = self._results[-1]
        return wins - losses

    def n(self):
        # Visit count.
        return self._number_of_visits
    def expand(self):
        # Turn one untried action into a new child node.
        action = self._untried_actions.pop()
        next_state = self.state.move(action)
        child_node = MonteCarloTreeSearchNode(next_state, parent=self, parent_action=action)
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_game_over()
    def rollout(self):
        # Simulate a random playout from this node. Unlike board games,
        # Water Sort can revisit earlier states, so track the states seen in
        # this rollout and treat a position with no unvisited successors as a loss.
        current_rollout_state = self.state
        visited_state = set()
        visited_state.add(current_rollout_state)
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            possible_states = [current_rollout_state.move(a) for a in possible_moves]
            # Drop successors we have already been through in this rollout.
            for i in range(len(possible_states) - 1, -1, -1):
                if possible_states[i] in visited_state:
                    possible_states.pop(i)
            if len(possible_states) == 0:
                return -1
            current_rollout_state = possible_states[np.random.randint(len(possible_states))]
            visited_state.add(current_rollout_state)
        return current_rollout_state.game_result()
    def backpropagate(self, result):
        # Propagate the rollout reward up to the root.
        self._number_of_visits += 1
        self._results[result] += 1
        if self.parent:
            self.parent.backpropagate(result)

    def is_fully_expanded(self):
        return len(self._untried_actions) == 0

    def best_child(self, c_param=0.1):
        # UCT: exploitation term q/n plus exploration term
        # c_param * sqrt(2 * ln(N_parent) / n_child).
        choices_weights = [
            (c.q() / c.n()) + c_param * np.sqrt(2 * np.log(self.n()) / c.n())
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves):
        # Uniform random choice (kept for reference; rollout() above picks
        # successor states directly so it can filter out visited ones).
        return possible_moves[np.random.randint(len(possible_moves))]
    def _tree_policy(self):
        # Selection + expansion: descend via UCT until reaching a node
        # that still has untried actions (or a terminal node).
        current_node = self
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node

    def best_action(self):
        # Run a fixed budget of simulations, then pick the most promising
        # child greedily (c_param=0 disables exploration).
        simulation_no = 100
        for i in range(simulation_no):
            v = self._tree_policy()
            reward = v.rollout()
            v.backpropagate(reward)
        return self.best_child(c_param=0.)
if __name__ == '__main__':
    initial_state = State([[],[],[4,3,2,1],[2,6,1,5],[1,8,6,7],[4,9,2,7],[10,5,3,4],[10,7,4,10],[9,8,11,3],[5,12,3,9],[11,6,5,12],[12,11,2,6],[10,9,12,1],[11,8,8,7]])
    root = MonteCarloTreeSearchNode(state=initial_state)
    round_no = 0
    while True:
        print('round:{}'.format(round_no))
        round_no += 1
        print(root)
        if root.is_terminal_node():
            break
        selected_node = root.best_action()
        root = selected_node
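The loop above only prints each chosen state. If you also want the actual pour sequence, one option is to collect parent_action while descending; below is a minimal sketch under the same setup (moves is my own name, not from the code above). Since rollouts are random, a single run can still dead-end in an unsolved terminal state, so it is worth checking game_result() at the end.

# Hypothetical variant of the main loop that records the pours taken.
moves = []
node = MonteCarloTreeSearchNode(state=initial_state)
while not node.is_terminal_node():
    node = node.best_action()
    moves.append(node.parent_action)      # the (from, to) pour just chosen
print(moves)
print('solved' if node.state.game_result() == 1 else 'stuck')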