tenet commited on
Commit
dcdebd7
·
verified ·
1 Parent(s): cb58f46

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -10
main.py CHANGED
@@ -1,7 +1,8 @@
1
  import math
2
  import random
3
  import numpy as np
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
  # Define the Node class for MCTS
7
  class Node:
@@ -110,15 +111,13 @@ class GameState:
110
 
111
  # Initialize the RWKV model and tokenizer
112
  model_name = "BlinkDL/rwkv-4-raven"
113
- tokenizer = AutoModelForCausalLM.from_pretrained(model_name)
114
- model = AutoModelForCausalLM.from_pretrained(model_name)
115
 
116
  # Generate Chain-of-Thought
117
  def generate_cot(state):
118
  input_text = f"Current state: {state}\nWhat is the best move?"
119
- inputs = tokenizer(input_text, return_tensors="pt")
120
- outputs = model.generate(**inputs, max_length=100, num_return_sequences=1)
121
- cot = tokenizer.decode(outputs[0], skip_special_tokens=True)
122
  return cot
123
 
124
  # Use CoT in MCTS
@@ -129,7 +128,4 @@ def mcts_with_cot(initial_state):
129
  return best_state, cot
130
 
131
  # Function to be called by Gradio
132
- def run_mcts_cot(initial_board):
133
- initial_state = GameState(initial_board, 1)
134
- best_state, cot = mcts_with_cot(initial_state)
135
- return str(best_state), cot
 
1
  import math
2
  import random
3
  import numpy as np
4
+ from rwkv.model import RWKV
5
+ from rwkv.utils import PIPELINE
6
 
7
  # Define the Node class for MCTS
8
  class Node:
 
111
 
112
  # Initialize the RWKV model and tokenizer
113
  model_name = "BlinkDL/rwkv-4-raven"
114
+ model = RWKV(model=model_name, strategy="cuda fp16")
115
+ pipeline = PIPELINE(model, "20B_tokenizer")
116
 
117
  # Generate Chain-of-Thought
118
  def generate_cot(state):
119
  input_text = f"Current state: {state}\nWhat is the best move?"
120
+ cot = pipeline.generate(input_text, max_tokens=100)
 
 
121
  return cot
122
 
123
  # Use CoT in MCTS
 
128
  return best_state, cot
129
 
130
  # Function to be called by Gradio
131
+ def run_mcts_c