tic_tac_toe / model.py
remosleandre
[FIX] weight_update
b46b06b
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import csv
class Architecture(nn.Module):
    """Fully connected network taking 9 input values and producing 9 outputs.

    The network is a stack of eight ``nn.Linear`` layers (``fc1`` .. ``fc8``),
    each 9 units wide, with a ReLU after every layer except the last.  The
    layer attribute names are kept stable because they form the keys of the
    ``state_dict`` that ``load_state_dict`` expects.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_size = 9
        self.hidden_size_1 = 9
        self.hidden_size_2 = 9
        self.hidden_size_3 = 9
        self.hidden_size_4 = 9
        self.hidden_size_5 = 9
        self.hidden_size_6 = 9
        self.hidden_size_7 = 9
        self.output_size = 9
        self.fc1 = nn.Linear(self.input_size, self.hidden_size_1)
        self.fc2 = nn.Linear(self.hidden_size_1, self.hidden_size_2)
        self.fc3 = nn.Linear(self.hidden_size_2, self.hidden_size_3)
        self.fc4 = nn.Linear(self.hidden_size_3, self.hidden_size_4)
        self.fc5 = nn.Linear(self.hidden_size_4, self.hidden_size_5)
        self.fc6 = nn.Linear(self.hidden_size_5, self.hidden_size_6)
        self.fc7 = nn.Linear(self.hidden_size_6, self.hidden_size_7)
        self.fc8 = nn.Linear(self.hidden_size_7, self.output_size)
        # Not used inside this class; kept as an attribute for external code.
        self.loss = nn.CrossEntropyLoss()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw output of ``fc8`` for input ``x``.

        Defining ``forward`` makes ``model(x)`` work like any other
        ``nn.Module`` and lets gradient-enabled callers share the exact layer
        stack that ``inference`` uses.
        """
        hidden_layers = (
            self.fc1, self.fc2, self.fc3, self.fc4,
            self.fc5, self.fc6, self.fc7,
        )
        for layer in hidden_layers:
            x = self.relu(layer(x))
        # The final layer has no activation: callers receive raw scores.
        return self.fc8(x)

    def inference(self, x):
        """Evaluate the network on ``x`` with gradient tracking disabled."""
        with torch.no_grad():
            # Call forward() directly (not self(x)) so no module hooks fire,
            # matching how the layers were applied before forward() existed.
            return self.forward(x)
def load_model(path='./model_weights.pth'):
    """Build an ``Architecture`` and load its weights from ``path``.

    Args:
        path: Location of the saved ``state_dict``.  Defaults to
            ``'./model_weights.pth'``, resolved against the current working
            directory.

    Returns:
        The ``Architecture`` instance with the loaded weights.
    """
    model = Architecture()
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # machine without CUDA; the freshly built model lives on the CPU anyway.
    # NOTE: torch.load unpickles the file - only load weights you trust.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    return model
def inference_model(model, input):
    """Return whatever ``model.inference`` produces for ``input``."""
    prediction = model.inference(input)
    return prediction
def minimax(board, depth, is_maximizing):
    """Return the minimax value of ``board`` from the AI's point of view.

    Args:
        board: Mutable sequence of 9 cells (0 empty, 1 human, 2 AI).  Cells
            are filled temporarily during the search and reset to 0 before
            returning.
        depth: Current recursion depth.  It is passed down to recursive
            calls but does not influence the returned score.
        is_maximizing: Truthy when it is the AI's (2) turn, falsy when it is
            the human's (1) turn.

    Returns:
        1 if the AI wins, -1 if the human wins, 0 for a draw, each assuming
        both sides play optimally from ``board``.
    """
    # Evaluate the position once instead of once per terminal check.
    outcome = check_winner(board)
    if outcome == 2:  # AI wins
        return 1
    if outcome == 1:  # Human wins
        return -1
    if outcome == 3:  # Draw
        return 0

    if is_maximizing:
        mark, choose, best_score = 2, max, -math.inf
    else:
        mark, choose, best_score = 1, min, math.inf

    for i in range(9):
        if board[i] == 0:
            board[i] = mark  # try the move ...
            score = minimax(board, depth + 1, not is_maximizing)
            board[i] = 0  # ... and undo it
            best_score = choose(score, best_score)
    return best_score
def best_move_minimax(board):
    """Return the index of the AI move with the highest minimax score.

    Every empty cell is tried as an AI (2) move, scored with ``minimax`` for
    the human to reply, and then reset to 0.  Among equally scored moves the
    lowest index is returned; ``None`` is returned when no cell is empty.
    """
    def _score_after_ai_plays(cell):
        board[cell] = 2  # AI's move
        value = minimax(board, 0, False)
        board[cell] = 0
        return value

    open_cells = [cell for cell in range(9) if board[cell] == 0]
    # max() returns the first maximal item, i.e. the lowest-index best move.
    return max(open_cells, key=_score_after_ai_plays, default=None)
def oracle(board, position):
    """Tell whether ``position`` is the move ``best_move_minimax`` would pick.

    Args:
        board: The 9-cell board before the move is applied.
        position: Index of the move being judged.

    Returns:
        1 when ``position`` equals the minimax move for ``board``, -1
        otherwise (including when minimax has no move to offer).
    """
    # The one-hot list the old code built here was never read; it is gone.
    minimax_move = best_move_minimax(board)
    return 1 if minimax_move == position else -1
def print_in_csv(board, position, comparison_result):
    """Append one row describing a move to ``game_state.csv``.

    The row has three fields: the board cells joined by commas, a 9-cell
    list that is 0 everywhere except for a 2 at ``position`` (also joined by
    commas), and ``comparison_result`` as given.  The file is opened in
    append mode in the current working directory.
    """
    move_cells = [0] * 9
    move_cells[position] = 2

    def _joined(cells):
        return ','.join(str(cell) for cell in cells)

    with open('game_state.csv', 'a', newline='') as handle:
        row = [_joined(board), _joined(move_cells), comparison_result]
        csv.writer(handle).writerow(row)
def decode_prediction(prediction, board):
    """Apply the model's preferred legal move to ``board``.

    The highest-scoring entry of ``prediction`` among the cells that are
    still empty is chosen, so the model can never overwrite a piece that is
    already on the board.  The choice is compared with minimax via
    ``oracle``, the result is printed and appended to the CSV log through
    ``print_in_csv``, and the cell is then marked with 2.

    Args:
        prediction: Tensor holding one score per board cell.
        board: Mutable 9-cell board (0 empty, 1 human, 2 AI); modified in
            place.

    Returns:
        The same ``board`` object with the chosen cell set to 2.

    Raises:
        ValueError: If ``board`` has no empty cell to play in.
    """
    open_cells = [idx for idx, cell in enumerate(board) if cell == 0]
    if not open_cells:
        raise ValueError('decode_prediction called on a board with no empty cells')

    scores = prediction.reshape(-1)
    # Restrict the argmax to legal moves; a plain argmax over all nine
    # outputs may land on an occupied square.
    position = max(open_cells, key=lambda idx: scores[idx].item())

    comparison_result = oracle(board, position)
    print(f'Oracle comparison result: {comparison_result}')
    print_in_csv(board, position, comparison_result)
    board[position] = 2
    return board
def encode_input(input):
    """Convert ``input`` into a new ``torch.float32`` tensor."""
    encoded = torch.tensor(input, dtype=torch.float32)
    return encoded
def play():
    """Load the saved model and print its output for one random input."""
    net = load_model()
    net.eval()
    sample = torch.randn(9)
    print(net.inference(sample))
def print_board(board):
    """Print the board as three rows, each followed by a dashed line.

    Cells equal to 0 are shown as a blank, cells equal to 1 as ``X`` and
    every other value as ``O``.
    """
    symbols = []
    for cell in board:
        if cell == 0:
            symbols.append(' ')
        elif cell == 1:
            symbols.append('X')
        else:
            symbols.append('O')

    separator = '---------'
    for start in (0, 3, 6):
        left, middle, right = symbols[start], symbols[start + 1], symbols[start + 2]
        print(f'{left} | {middle} | {right}')
        print(separator)
def check_winner(board):
    """Report the state of a 9-cell board.

    Returns:
        1 or 2 when that player holds a complete row, column or diagonal,
        3 when no line is complete and no cell equals 0 (draw), and 0
        otherwise.  Lines are examined rows first, then columns, then the
        two diagonals, testing player 1 before player 2 on each line.
    """
    rows = [(start, start + 1, start + 2) for start in range(0, 9, 3)]
    columns = [(start, start + 3, start + 6) for start in range(3)]
    diagonals = [(0, 4, 8), (2, 4, 6)]

    for a, b, c in rows + columns + diagonals:
        for player in (1, 2):
            if board[a] == board[b] == board[c] == player:
                return player

    # No completed line: a full board is a draw, anything else is ongoing.
    return 3 if all(cell != 0 for cell in board) else 0
def simulate_game():
    """Play one interactive game: the human is 1 (X), the model is 2 (O).

    The human enters a cell number from 1 to 9 on each turn.  After every
    valid human move the model's reply is obtained with ``inference_model``
    and applied with ``decode_prediction``.  The loop ends when
    ``check_winner`` reports a win for either side or a draw.
    """
    board = [0] * 9
    # Load the weights once up front instead of re-reading the file on every
    # computer turn.
    model = load_model()

    while True:
        print_board(board)
        try:
            move = int(input('Enter move: '))
        except ValueError:
            # Non-numeric input is just another invalid move, not a crash.
            print('Invalid move')
            continue
        if not (1 <= move <= 9 and board[move - 1] == 0):
            print('Invalid move')
            continue

        board[move - 1] = 1
        print_board(board)
        outcome = check_winner(board)
        if outcome == 1:
            print('You win')
            break
        if outcome == 3:
            print('Draw')
            break

        print('Computer is thinking...')
        encoded_board = encode_input(board)
        prediction = inference_model(model, encoded_board)
        board = decode_prediction(prediction, board)
        outcome = check_winner(board)
        if outcome == 2:
            print_board(board)
            print('Computer wins')
            break
        if outcome == 3:
            print('Draw')
            break
# Start an interactive game only when this file is executed as a script.
if '__main__' == __name__:
    simulate_game()