Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import json | |
import gradio as gr | |
# --- Step 1: Create a "Smart" Vocabulary Loader ---
# Loads both vocabulary files and normalizes the language vocabulary to an
# int -> language-name mapping, whichever way round it was saved on disk.
def load_vocabularies(char_vocab_path='char_to_int.json',
                      lang_vocab_path='int_to_lang.json'):
    """
    Load the character and language vocabularies from JSON files.

    The language file may be stored either as {"0": "Python", ...}
    (index -> language) or as {"Python": 0, ...} (language -> index).
    The format is detected from the first key, and the result is always
    normalized to {int index: language name}.

    Args:
        char_vocab_path: Path of the character -> index JSON file.
        lang_vocab_path: Path of the language vocabulary JSON file.

    Returns:
        A tuple (char_to_int_map, int_to_lang_map).

    Raises:
        ValueError: If the language vocabulary is empty, or if its keys/values
            cannot be converted to integer indices.
    """
    with open(char_vocab_path, 'r', encoding='utf-8') as f:
        char_to_int_map = json.load(f)
    with open(lang_vocab_path, 'r', encoding='utf-8') as f:
        language_vocab = json.load(f)

    if not language_vocab:
        # Without this check next(iter(...)) raises an opaque StopIteration.
        raise ValueError(f"Language vocabulary '{lang_vocab_path}' is empty.")

    # Inspect only the first key to decide the format (e.g. "0" vs "C#").
    first_key = next(iter(language_vocab))
    try:
        int(first_key)
    except ValueError:
        # Keys are language names, i.e. {"Language": 0}: reverse the mapping.
        # Indices are coerced to int in case they were saved as strings ("0").
        print("[INFO] Detected lang->int format. Reversing dictionary to fix.")
        int_to_lang_map = {int(v): k for k, v in language_vocab.items()}
    else:
        # Keys are indices, i.e. {"0": "Language"}: JSON keys are always
        # strings, so convert them back to int.
        print("[INFO] Detected int->lang format. Loading directly.")
        int_to_lang_map = {int(k): v for k, v in language_vocab.items()}
    return char_to_int_map, int_to_lang_map
# Build the lookup tables once at import time: char_to_int maps each character
# to its embedding index, int_to_lang maps each class index to a language name.
char_to_int, int_to_lang = load_vocabularies()
# --- Step 2: Re-define the Model Architecture ---
# This MUST be the exact same architecture as the one you trained.
class CodeClassifierRNN(nn.Module):
    """
    Character-level LSTM classifier for code snippets.

    Pipeline: token indices -> embedding -> (optionally bidirectional)
    multi-layer LSTM -> final hidden state of the top layer -> dropout ->
    linear layer producing one score per language class.

    The submodule attribute names (embedding, lstm, dropout, fc) determine the
    state_dict keys, so they must match the saved checkpoint.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        # nn.LSTM applies dropout only between stacked layers, so it is
        # disabled when there is a single layer.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # The classifier input is the top layer's hidden state; a
        # bidirectional LSTM contributes one state per direction.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_dim)

    def forward(self, text):
        """
        Compute unnormalized class scores.

        Args:
            text: LongTensor of token indices, expected shape (batch, seq_len)
                because the LSTM is built with batch_first=True.

        Returns:
            Tensor of shape (batch, output_dim) with one logit per class.
        """
        embedded = self.embedding(text)
        _, (hidden, _) = self.lstm(embedded)
        # hidden is (num_layers * num_directions, batch, hidden_dim). For a
        # bidirectional LSTM the last two entries are the top layer's forward
        # and backward states; otherwise only the last entry is the top layer.
        if self.lstm.bidirectional:
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            hidden = hidden[-1, :, :]
        hidden = self.dropout(hidden)
        return self.fc(hidden)
# --- Step 3: Instantiate the model and load the trained weights ---
# Sizes derived from the vocabularies.
PAD_IDX = char_to_int['<PAD>']
VOCAB_SIZE = len(char_to_int)
OUTPUT_DIM = len(int_to_lang)

# Hyperparameters used at training time. The size/layer settings must match
# the checkpoint, or load_state_dict() will reject it.
EMBEDDING_DIM = 128
HIDDEN_DIM = 192
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5

model = CodeClassifierRNN(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    n_layers=N_LAYERS,
    bidirectional=BIDIRECTIONAL,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('polyglot_classifier.pt', map_location='cpu'))
# Switch to inference mode so dropout is disabled during prediction.
model.eval()
# --- Step 4: Create the prediction function ---
def classify_code(code_snippet, top_k=5):
    """
    Predict the programming language of a code snippet.

    Args:
        code_snippet: Source text to classify. None or whitespace-only input
            returns an empty dict without running the model.
        top_k: Maximum number of languages to return (default 5). Clamped to
            the number of classes the model outputs.

    Returns:
        dict mapping language name -> softmax probability for the top
        predictions, the format gr.Label consumes.
    """
    if not code_snippet or not code_snippet.strip():
        return {}
    # Characters not seen during training map to the <UNK> index.
    unk_idx = char_to_int['<UNK>']
    indexed = [char_to_int.get(c, unk_idx) for c in code_snippet]
    tensor = torch.LongTensor(indexed).unsqueeze(0)  # batch of one: (1, seq_len)
    with torch.no_grad():
        logits = model(tensor)
    probabilities = torch.softmax(logits, dim=1)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.size(1))
    top_probs, top_indices = torch.topk(probabilities, k)
    # int_to_lang is keyed by int, so convert the index tensors with .item().
    return {
        int_to_lang[idx.item()]: prob.item()
        for idx, prob in zip(top_indices[0], top_probs[0])
    }
# --- Step 5: Create and launch the Gradio Interface ---
# Text in, top-5 language probabilities out (classify_code returns a
# {language: probability} dict, which gr.Label renders).
iface = gr.Interface(
    fn=classify_code,
    inputs=gr.Code(language=None, label="Code Snippet"),
    outputs=gr.Label(num_top_classes=5, label="Predicted Language"),
    title="Polyglot Code Classifier",
    description="Enter a code snippet to see which programming language the AI thinks it is. This model was trained from scratch on a custom dataset.",
    examples=[
        ["def hello_world():\n print('Hello from Python!')"],
        ["function greet() {\n console.log('Hello from JavaScript!');\n}"],
        ["public class Main {\n public static void main(String[] args) {\n System.out.println(\"Hello, Java!\");\n }\n}"]
    ]
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse classify_code or the model) does not launch the app.
if __name__ == "__main__":
    iface.launch()