tenet committed on
Commit
157810d
·
verified ·
1 Parent(s): e7fd1b9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +19 -5
main.py CHANGED
@@ -1,8 +1,15 @@
 
 
1
  import math
2
  import random
3
  import numpy as np
 
4
  from rwkv.model import RWKV
5
- from rwkv.utils import PIPELINE
 
 
 
 
6
 
7
  # Define the Node class for MCTS
8
  class Node:
@@ -111,13 +118,17 @@ class GameState:
111
 
112
  # Initialize the RWKV model and tokenizer
113
  model_name = "BlinkDL/rwkv-4-raven"
114
- model = RWKV(model=model_name, strategy="cuda fp16")
115
- pipeline = PIPELINE(model, "20B_tokenizer")
 
 
116
 
117
  # Generate Chain-of-Thought
118
  def generate_cot(state):
119
  input_text = f"Current state: {state}\nWhat is the best move?"
120
- cot = pipeline.generate(input_text, max_tokens=100)
 
 
121
  return cot
122
 
123
  # Use CoT in MCTS
@@ -128,4 +139,7 @@ def mcts_with_cot(initial_state):
128
  return best_state, cot
129
 
130
  # Function to be called by Gradio
131
- def run_mcts_c
 
 
 
 
1
+ import sys
2
+ import os
3
  import math
4
  import random
5
  import numpy as np
6
+ import gradio as gr
7
  from rwkv.model import RWKV
8
+ from transformers import AutoTokenizer
9
+
10
+ # Add RWKV directory to Python path
11
+ rwkv_dir = os.path.join(os.getcwd(), 'RWKV-LM')
12
+ sys.path.append(rwkv_dir)
13
 
14
  # Define the Node class for MCTS
15
  class Node:
 
118
 
119
  # Initialize the RWKV model and tokenizer
120
  model_name = "BlinkDL/rwkv-4-raven"
121
+ tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a tokenizer from a supported model
122
+
123
+ # Load the RWKV model
124
+ model = RWKV(model_name=model_name, strategy="cuda fp16")
125
 
126
  # Generate Chain-of-Thought
127
  def generate_cot(state):
128
  input_text = f"Current state: {state}\nWhat is the best move?"
129
+ inputs = tokenizer(input_text, return_tensors="pt")
130
+ outputs = model.generate(inputs.input_ids, max_length=100, num_return_sequences=1)
131
+ cot = tokenizer.decode(outputs[0], skip_special_tokens=True)
132
  return cot
133
 
134
  # Use CoT in MCTS
 
139
  return best_state, cot
140
 
141
  # Function to be called by Gradio
142
+ def run_mcts_cot(initial_board):
143
+ initial_state = GameState(initial_board, 1)
144
+ best_state, cot = mcts_with_cot(initial_state)
145
+ return str(best_state), cot