tenet committed on
Commit
5f6e38c
·
verified ·
1 Parent(s): 17c510a

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +149 -0
main.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import math
4
+ import random
5
+ import numpy as np
6
+
7
# Define the Node class for MCTS
class Node:
    """A single node in the Monte Carlo search tree.

    Holds the game state the node represents, parent/children links, and
    the visit/win statistics consumed by the UCT selection formula.
    """

    def __init__(self, state, parent=None):
        self.state = state      # game state this node represents
        self.parent = parent    # parent Node, or None for the root
        self.children = []      # expanded child Nodes
        self.visits = 0         # number of simulations routed through this node
        self.wins = 0           # cumulative reward backpropagated into this node

    def is_fully_expanded(self):
        # NOTE(review): a node is treated as fully expanded as soon as it has
        # one child, so each node is only ever expanded once -- confirm this
        # matches the intended (simplified) MCTS variant.
        return len(self.children) > 0

    def best_child(self, c_param=1.4):
        """Return the child maximizing the UCT score.

        UCT = win_rate + c_param * sqrt(2 * ln(parent_visits) / child_visits).
        An unvisited child scores +inf so it is always selected first; this
        also fixes the ZeroDivisionError the original raised when any child
        had ``visits == 0``.
        """
        def uct(child):
            if child.visits == 0:
                # Standard MCTS convention: always try unvisited children first.
                return float("inf")
            exploit = child.wins / child.visits
            explore = c_param * math.sqrt(2 * math.log(self.visits) / child.visits)
            return exploit + explore

        return max(self.children, key=uct)

    def expand(self, state):
        """Create a child node wrapping *state*, attach it, and return it."""
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node
29
+
30
# Define the MCTS class
class MCTS:
    """Monte Carlo Tree Search driver.

    Runs a fixed number of select/expand/simulate/backpropagate iterations
    from an initial state and returns the most promising immediate child
    state. The state object must provide ``is_terminal()``,
    ``get_random_child_state()``, and ``get_reward()``.
    """

    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit  # simulations per search()

    def search(self, initial_state):
        """Run the simulations and return the best immediate child state."""
        self.root = Node(initial_state)
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)
            reward = self.default_policy(node.state)
            self.backpropagate(node, reward)
        # If the initial state is already terminal the root never expands;
        # return it unchanged instead of crashing on an empty children list.
        if not self.root.children:
            return initial_state
        # c_param=0 -> pure exploitation: pick the child with the best win rate.
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        """Descend the tree, expanding the first non-fully-expanded node."""
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            else:
                node = node.best_child()
        return node

    def expand(self, node):
        """Attach one new child whose state was not already tried under *node*."""
        # NOTE(review): this dedup relies on state equality (==); if the state
        # type does not define __eq__ the membership test never filters
        # anything. With Node.is_fully_expanded returning True after one
        # child, expand() is only reached with an empty children list anyway.
        tried_states = [child.state for child in node.children]
        new_state = node.state.get_random_child_state()
        while new_state in tried_states:
            new_state = node.state.get_random_child_state()
        return node.expand(new_state)

    def default_policy(self, state):
        """Play random moves until a terminal state; return its reward."""
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward()

    def backpropagate(self, node, reward):
        """Add the simulation result to every node up to the root."""
        while node is not None:
            node.visits += 1
            node.wins += reward
            node = node.parent
69
+
70
# Define the Game State and Rules
class GameState:
    """A Tic-Tac-Toe position.

    board  -- 3x3 list of lists; 0 = empty cell, 1 / 2 = player marks.
    player -- the player (1 or 2) whose turn it is to move NEXT.
    """

    def __init__(self, board, player):
        self.board = board
        self.player = player

    def is_terminal(self):
        """A position is terminal when someone has won or the board is full."""
        return self.check_win() or self.check_draw()

    def check_win(self):
        """Return True if any row, column, or diagonal holds three equal non-zero marks."""
        for row in self.board:
            if row.count(row[0]) == len(row) and row[0] != 0:
                return True
        for col in range(len(self.board)):
            if self.board[0][col] == self.board[1][col] == self.board[2][col] and self.board[0][col] != 0:
                return True
        if self.board[0][0] == self.board[1][1] == self.board[2][2] and self.board[0][0] != 0:
            return True
        if self.board[0][2] == self.board[1][1] == self.board[2][0] and self.board[0][2] != 0:
            return True
        return False

    def check_draw(self):
        """Return True when no empty cell remains (a win takes precedence in is_terminal)."""
        return all(self.board[row][col] != 0 for row in range(len(self.board)) for col in range(len(self.board)))

    def get_random_child_state(self):
        """Return a new GameState after one uniformly random legal move.

        Returns self unchanged when the board is full so callers cannot loop
        looking for a nonexistent move.
        """
        size = len(self.board)
        available_moves = [(r, c) for r in range(size) for c in range(size) if self.board[r][c] == 0]
        if not available_moves:
            return self
        row, col = random.choice(available_moves)
        new_board = [r.copy() for r in self.board]
        new_board[row][col] = self.player
        # 3 - player toggles between players 1 and 2.
        return GameState(new_board, 3 - self.player)

    def get_reward(self):
        """Reward from player 1's perspective: +1 win, -1 loss, 0 draw/non-terminal.

        BUG FIX: in a won position, self.player has already been flipped to
        the player who moves next, so the winner is 3 - self.player. The
        original returned +1 when self.player == 1, i.e. the sign of the
        LOSING player.
        """
        if self.check_win():
            winner = 3 - self.player  # the player who made the last move
            return 1 if winner == 1 else -1
        return 0

    def __eq__(self, other):
        # Value equality on (board, player) so the MCTS expand() dedup
        # (`state in tried_states`) actually compares positions rather than
        # object identity.
        if not isinstance(other, GameState):
            return NotImplemented
        return self.board == other.board and self.player == other.player

    def __hash__(self):
        # Keep GameState hashable now that __eq__ is defined.
        return hash((tuple(map(tuple, self.board)), self.player))

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)
111
+
112
# Initialize the RWKV model and tokenizer
# NOTE(review): this loads "BlinkDL/rwkv-4-raven" through the generic Auto*
# classes; confirm the checkpoint ships HF-compatible config/tokenizer files,
# since RWKV checkpoints often require a dedicated class or trust_remote_code.
model_name = "BlinkDL/rwkv-4-raven"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
116
+
117
# Generate Chain-of-Thought
def generate_cot(state):
    """Ask the language model for chain-of-thought reasoning about *state*.

    Builds a short prompt from the state's string form, generates a
    continuation, and returns only the newly generated text.
    """
    input_text = f"Current state: {state}\nWhat is the best move?"
    inputs = tokenizer(input_text, return_tensors="pt")
    # max_new_tokens bounds the *generated* length; the original's
    # max_length=100 counted prompt tokens too and could truncate the
    # generation to nothing on a longer prompt.
    outputs = model.generate(**inputs, max_new_tokens=100, num_return_sequences=1)
    # Slice off the prompt tokens so the returned CoT does not echo the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    cot = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return cot
124
+
125
# Use CoT in MCTS
def mcts_with_cot(initial_state):
    """Search from *initial_state* with MCTS, then explain the result.

    Returns a (best_state, chain_of_thought) pair.
    """
    searcher = MCTS(simulation_limit=1000)
    chosen_state = searcher.search(initial_state)
    reasoning = generate_cot(chosen_state)
    return chosen_state, reasoning
131
+
132
# Gradio Interface
def run_mcts_cot(initial_board):
    """Gradio handler: 3x3 board of ints -> (board string, CoT string)."""
    state = GameState(initial_board, 1)  # player 1 always moves first
    best_state, cot = mcts_with_cot(state)
    return str(best_state), cot
137
+
138
# Create the Gradio interface
initial_board = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]  # empty 3x3 board, shown as the default input
iface = gr.Interface(
    fn=run_mcts_cot,
    # BUG FIX: the gr.inputs namespace was removed in Gradio 3+; use the
    # top-level gr.JSON component and pre-fill it with the empty board so
    # the previously-unused initial_board actually seeds the demo.
    inputs=gr.JSON(value=initial_board),
    outputs=["text", "text"],
    title="RWKV CoT Demo for MCTS",
    description="This demo uses RWKV to generate Chain-of-Thought reasoning to guide the MCTS algorithm in a Tic-Tac-Toe game."
)

# Launch the interface
iface.launch()