Spaces:
Running
Running
import json | |
import torch | |
import random | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
import numpy as np | |
# Load the training dataset to obtain the intent tags and their candidate responses.
# The encoding is explicit so the (Indonesian) text decodes identically on every
# platform; without it, open() falls back to the locale's preferred encoding.
with open('datasets_new.json', 'r', encoding='utf-8') as f:
    datasets = json.load(f)
# Sorted, de-duplicated list of intent tags. get_response() interprets model output
# index i as tags[i], so this ordering must match the label ids used in training.
# NOTE(review): assumes the training script also used sorted(set(tags)) — confirm.
tags = sorted({intent['tag'] for intent in datasets['intents']})
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Load the fine-tuned IndoBERT tokenizer and sequence classifier from the local
# "indobert_model" directory.
tokenizer = AutoTokenizer.from_pretrained("indobert_model")
model = AutoModelForSequenceClassification.from_pretrained("indobert_model")
# Inference only: place the weights on the chosen device and switch the module
# to evaluation mode (training-only behaviour such as dropout is turned off).
model.to(device)
model.eval()
# get_response is copied from chat_indobert.py (instead of imported) so this
# evaluation script is self-contained.
def get_response(msg, *, threshold=0.2, verbose=True):
    """Classify ``msg`` and pick a canned response for the predicted intent.

    Args:
        msg: The user message to classify.
        threshold: The top softmax probability must be strictly greater than this
            for the predicted intent's response to be returned (default 0.2, the
            value previously hard-coded).
        verbose: When True (the default, matching the original behaviour), print
            the probability of every tag and the top prediction.

    Returns:
        A ``(response, tag)`` tuple. ``response`` is a random entry of the predicted
        intent's ``responses`` list. When the top probability does not exceed
        ``threshold`` (or the tag has no matching intent), the fallback message and
        the tag ``'notfound'`` are returned instead.

    Raises:
        ValueError: If the model's number of output labels differs from
            ``len(tags)``, in which case the index-to-tag mapping is unreliable.
    """
    # Tokenise to a fixed length of 20 tokens (presumably the length used during
    # training — TODO confirm) and move the tensors to the model's device.
    encoding = tokenizer(msg, padding='max_length', truncation=True, max_length=20, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Forward pass without recording an autograd graph.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    probs = torch.softmax(outputs.logits, dim=1)

    # Output column i is read as tags[i]. If the sizes disagree, zip() below would
    # silently truncate and tags[i] could raise IndexError, so fail loudly instead.
    num_labels = probs.shape[1]
    if num_labels != len(tags):
        raise ValueError(
            f"Model has {num_labels} output labels but {len(tags)} tags were loaded; "
            "the label-to-tag mapping would be wrong."
        )

    if verbose:
        print("Probabilities for each tag:")
        for tag, prob in zip(tags, probs[0].tolist()):
            print(f"Tag: {tag}, Prob: {prob:.4f}")

    max_prob, predicted = torch.max(probs, dim=1)
    tag = tags[predicted.item()]
    prob = max_prob.item()
    if verbose:
        print(f"Max Prob: {prob}, Predicted Tag: {tag}")

    # Only answer when the model is confident enough; otherwise use the fallback.
    if prob > threshold:
        for dataset in datasets['intents']:
            if tag == dataset["tag"]:
                return random.choice(dataset['responses']), tag
    return "Hi, maaf saya tidak tahu soal itu...", 'notfound'
# Load the held-out test set. The evaluation reads item['text'] and
# item['true_intent'] from each entry.
import sys

try:
    with open('datasets_uji_mandiri.json', 'r', encoding='utf-8') as f:
        test_data = json.load(f)
except FileNotFoundError:
    # Report on stderr and stop with a non-zero status. sys.exit() is used rather
    # than the bare exit() helper, which the site module injects for interactive
    # use and which is missing when Python runs without site (python -S).
    print("Error: File 'datasets_uji_mandiri.json' tidak ditemukan. Pastikan file ada di direktori saat ini.", file=sys.stderr)
    sys.exit(1)
# Evaluate every test question: gather gold and predicted intents in parallel
# lists and keep a detailed record of each misclassification.
true_intents, predicted_intents, errors = [], [], []
for item in test_data:
    question, true_intent = item['text'], item['true_intent']
    # The chosen response is kept only so it can be shown for wrong predictions.
    response, predicted_intent = get_response(question)
    true_intents.append(true_intent)
    predicted_intents.append(predicted_intent)
    if predicted_intent == true_intent:
        continue
    errors.append(dict(
        question=question,
        true_intent=true_intent,
        predicted_intent=predicted_intent,
        response=response,
    ))
# Overall metrics. Accuracy is a percentage; precision, recall and F1 are
# support-weighted averages over the labels, with 0 reported for any label whose
# metric is undefined (zero_division=0).
accuracy = accuracy_score(true_intents, predicted_intents) * 100
precision, recall, f1, _ = precision_recall_fscore_support(true_intents, predicted_intents, average='weighted', zero_division=0)

# Confusion matrix. get_response() predicts 'notfound' when no intent is confident
# enough, and 'notfound' also appears as a gold label. confusion_matrix ignores any
# sample whose true or predicted label is not listed in `labels`, so append
# 'notfound' when the training tags lack it. Otherwise those samples would silently
# vanish from the matrix. It becomes the last row/column.
cm_labels = tags if 'notfound' in tags else tags + ['notfound']
conf_matrix = confusion_matrix(true_intents, predicted_intents, labels=cm_labels)

# "False intent" analysis: test questions whose gold label is 'notfound'.
# The error rate is the percentage of those questions that were misclassified,
# or 0 when the test set contains none.
false_intent_errors = [e for e in errors if e['true_intent'] == 'notfound']
false_intent_total = sum(1 for item in test_data if item['true_intent'] == 'notfound')
false_intent_error_rate = len(false_intent_errors) / false_intent_total * 100 if false_intent_total > 0 else 0
# Print the report (labels are in Indonesian): headline metrics first, then every
# misclassified question, then the raw confusion matrix.
for summary_line in (
    f"Akurasi: {accuracy:.2f}%",
    f"Presisi: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1-Score: {f1:.4f}",
    f"Kesalahan Total: {len(errors)} kasus",
    f"Kesalahan pada False Intent: {len(false_intent_errors)} kasus ({false_intent_error_rate:.2f}% dari false intent)",
):
    print(summary_line)

print("\nDetail Kesalahan:")
for err in errors:
    print(f"Pertanyaan: {err['question']}")
    print(f"Benar: {err['true_intent']}, Prediksi: {err['predicted_intent']}")
    # Blank line after each record to separate entries.
    print(f"Respons: {err['response']}", end="\n\n")

print("Confusion Matrix:", conf_matrix, sep="\n")