https://ai-boson.github.io/mcts/

https://en.wikipedia.org/wiki/Monte_Carlo_tree_search

I had never looked into game trees or Monte Carlo tree search before. At first glance the problem didn't seem to branch much, so I went straight to brute-force BFS, which blew through memory (60GB+) at around depth 30. Even after adding score-based pruning and ordering, I couldn't get past depth 35.
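
For reference, that brute-force attempt looked roughly like the sketch below (a reconstruction of the idea, not my original code; it relies on the State class defined further down). The frontier holds every distinct state up to the current depth, so memory grows roughly exponentially with depth:

from collections import deque

def bfs_solve(initial_state, max_depth=35):
	# Plain breadth-first search with a visited set. The frontier keeps
	# (state, path) pairs for whole depth levels at once, which is what
	# exhausts tens of gigabytes of memory around depth 30.
	frontier = deque([(initial_state, [])])
	seen = {initial_state}
	while frontier:
		state, path = frontier.popleft()
		if state.game_result() == 1:
			return path
		if len(path) >= max_depth:
			continue
		for action in state.get_legal_actions():
			nxt = state.move(action)
			if nxt not in seen:
				seen.add(nxt)
				frontier.append((nxt, path + [action]))
	return None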

Now, with the MCTS approach, the result is truly amazing!!!

One thing to watch out for: unlike games such as Gomoku, the water sort puzzle can return to earlier positions, so an unconstrained rollout may loop for a very long time. I therefore added a set in the rollout to rule out repeated states.
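
To see the kind of cycle involved, consider this contrived four-container position (made up for illustration, using the State class from the code below): pouring two 1s from the first container onto the second and then pouring them straight back restores the starting state, so an unconstrained random rollout can bounce between such states indefinitely.

s = State([[2,1,1,1], [3,1], [2,2,2], [3,3,3]])
s2 = s.move((0,1))   # two 1s move onto the 1 in container 1
s3 = s2.move((1,0))  # the reverse pour is also legal
assert s3 == s       # back where we started: a two-move cycle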

import numpy as np 
from collections import defaultdict
from copy import deepcopy

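# A State is a list of containers; each container is a list of color ids
# ordered bottom -> top, and each container holds at most 4 units.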
class State():
	def __init__(self,state):
		self.state = state 
		assert isinstance(self.state,list)
		for item in state:
			assert isinstance(item,list)
		self.num_of_container = len(self.state)
		cnt = 0
		for item in state:
			cnt += len(item)
		assert cnt%4==0
		self.num_of_color = cnt//4

	def __eq__(self,other):
		if not isinstance(other,State):
			return False
		return self.state==other.state

	def __hash__(self):
		# Hash must agree with __eq__; hashing each container as a tuple
		# keeps the result order-sensitive, so permuted contents rarely collide.
		r = 0
		for container in self.state:
			r = r*31 + hash(tuple(container))
		return r

	@staticmethod
	def same_color_in_container(container):
		if len(container)==0:
			return True 
		color = container[0]
		for i in range(1,len(container)):
			if container[i]!=color:
				return False
		return True

	def __check(self,from_idx,to_idx):
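		# A pour from_idx -> to_idx is legal iff:
		#  - both indices are valid and distinct,
		#  - the source is non-empty and the destination is not full,
		#  - the destination is empty or its top color matches the source's,
		#  - the source is not an already-completed (full, single-color) container,
		#  - and a single-color source never pours into an empty container,
		#    since that move cannot make progress.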
		if from_idx<0 or from_idx>=self.num_of_container or to_idx<0 or to_idx>=self.num_of_container or from_idx==to_idx:
			return False
		fv = self.state[from_idx]
		tv = self.state[to_idx]
		if len(fv)==0:
			return False
		if len(tv)==4:
			return False
		if len(tv)>0 and tv[-1]!=fv[-1]:
			return False
		if len(fv)==4 and fv[0]==fv[1]==fv[2]==fv[3]:
			return False
		if len(tv)==0 and State.same_color_in_container(fv):
			return False
		return True

	def get_legal_actions(self):
		actions = []
		for i in range(self.num_of_container):
			for j in range(self.num_of_container):
				if self.__check(i,j):
					actions.append((i,j))
		return actions

	def move(self,action):
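		# Return a new State with the pour applied: move as many consecutive
		# top units of the same color as the destination has room for.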
		ret = deepcopy(self)
		from_idx = action[0]
		to_idx = action[1]
		fv = ret.state[from_idx]
		tv = ret.state[to_idx]
		c = fv[-1]
		cnt = 1
		i = len(fv)-2
		while i>=0 and fv[i]==c:
			i -= 1
			cnt += 1 
		cnt = min(cnt,4-len(tv))
		for i in range(cnt):
			fv.pop()
			tv.append(c)
		return ret

	def game_result(self):
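		# +1 if solved (every container empty or full of a single color),
		# -1 otherwise.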
		for v in self.state:
			if len(v)==0:
				continue
			if len(v)<4:
				return -1
			for i in range(3):
				if v[i]!=v[i+1]:
					return -1
		return 1

	def is_game_over(self):
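		# The game ends when no legal pour remains: either solved or stuck.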
		return len(self.get_legal_actions())==0

	def __str__(self):
		s = ''
		for item in self.state:
			s += str(item)+'\n'
		return s




class MonteCarloTreeSearchNode():
	def __init__(self,state,parent=None,parent_action=None):
		self.state = state
		self.parent = parent
		self.parent_action = parent_action
		self.children = []
		self._number_of_visits = 0
		self._results = defaultdict(int)  # maps result (+1 win / -1 loss) to a count
		self._untried_actions = self.untried_actions()

	def __str__(self):
		s = str(self.state)
		s += 'num_of_vis={}\n'.format(self._number_of_visits)
		s += 'wins={} loss={}\n'.format(self._results[1],self._results[-1])
		return s

	def untried_actions(self):
		self._untried_actions = self.state.get_legal_actions()
		return self._untried_actions

	def q(self):
		# Net reward of this node: wins minus losses.
		wins = self._results[1]
		losses = self._results[-1]
		return wins-losses

	def n(self):
		return self._number_of_visits

	def expand(self):
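		# Pop one untried action, apply it, and attach the resulting
		# child node to the tree.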
		action = self._untried_actions.pop()
		next_state = self.state.move(action)
		child_node = MonteCarloTreeSearchNode(next_state,parent=self,parent_action=action)
		self.children.append(child_node)
		return child_node

	def is_terminal_node(self):
		return self.state.is_game_over()

	def rollout(self):
		# Simulate a random playout from this state. The visited_state set
		# blocks the cycles described above: if every successor has already
		# been seen in this playout, score it as a loss.
		current_rollout_state = self.state
		visited_state = set()
		visited_state.add(current_rollout_state)
		while not current_rollout_state.is_game_over():
			possible_moves = current_rollout_state.get_legal_actions()
			possible_states = [current_rollout_state.move(a) for a in possible_moves]
			# Keep only successors not yet visited in this playout.
			possible_states = [s for s in possible_states if s not in visited_state]
			if len(possible_states)==0:
				return -1
			current_rollout_state = possible_states[np.random.randint(len(possible_states))]
			visited_state.add(current_rollout_state)
		return current_rollout_state.game_result()

	def backpropagate(self,result):
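		# Propagate the playout result up to the root, updating
		# visit counts and win/loss tallies along the way.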
		self._number_of_visits += 1
		self._results[result] += 1
		if self.parent:
			self.parent.backpropagate(result)

	def is_fully_expanded(self):
		return len(self._untried_actions)==0

	def best_child(self,c_param=0.1):
		# UCT: average reward plus an exploration bonus,
		# q/n + c * sqrt(2 ln(N) / n).
		choices_weights = [(c.q() / c.n()) + c_param * np.sqrt((2 * np.log(self.n()) / c.n())) for c in self.children]
		return self.children[np.argmax(choices_weights)]

	def rollout_policy(self,possible_moves):
		# Uniform random move choice; unused above because rollout() samples
		# successor states directly after filtering out visited ones.
		return possible_moves[np.random.randint(len(possible_moves))]

	def _tree_policy(self):
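		# Selection + expansion: descend via best_child (UCT) until reaching
		# a node with untried actions, then expand it.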
		current_node = self
		while not current_node.is_terminal_node():
			if not current_node.is_fully_expanded():
				return current_node.expand()
			else:
				current_node = current_node.best_child()
		return current_node

	def best_action(self):
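		# Run a fixed budget of simulations from this node, then pick the
		# child purely by value (c_param=0 disables exploration).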
		simulation_no = 100
		for i in range(simulation_no):
			v = self._tree_policy()
			reward = v.rollout()
			v.backpropagate(reward)
		return self.best_child(c_param=0.)

if __name__=='__main__':
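	# A 14-container instance: colors 1-12, four units each,
	# plus two empty containers.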
	
	initial_state = State([[],[],[4,3,2,1],[2,6,1,5],[1,8,6,7],[4,9,2,7],[10,5,3,4],[10,7,4,10],[9,8,11,3],[5,12,3,9],[11,6,5,12],[12,11,2,6],[10,9,12,1],[11,8,8,7]])
	root = MonteCarloTreeSearchNode(state = initial_state)
	round_no = 0
	while True:
		print('round:{}'.format(round_no))
		round_no += 1
		print(root)
		if root.is_terminal_node():
			break
		selected_node = root.best_action()
		root = selected_node
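
Note that the loop re-roots the tree at the selected child after every move, so the visit counts and win statistics already accumulated in that subtree carry over into the next search. The loop ends when the current state has no legal pour left, i.e. the puzzle is either solved or dead-ended.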
		
