szili2011 committed on
Commit
0cf02bf
·
verified ·
1 Parent(s): 6925566

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import json
4
+ import gradio as gr
5
+
6
# --- Step 1: Load the vocabularies ---
# Both JSON files are expected to sit next to app.py in the Hugging Face Space
# repository, so they are opened by relative path.
# Decode explicitly as UTF-8: JSON is UTF-8 by specification, while the
# platform default encoding used by open() is not guaranteed to be, and the
# character vocabulary may contain non-ASCII entries.
with open('char_to_int.json', 'r', encoding='utf-8') as f:
    # Maps a character (plus the special '<PAD>' / '<UNK>' tokens) to its index.
    char_to_int = json.load(f)
with open('int_to_lang.json', 'r', encoding='utf-8') as f:
    # Maps a class index, stored as a string key, to a language name.
    int_to_lang = json.load(f)
12
+
13
+ # --- Step 2: Re-define the Model Architecture ---
14
+ # This MUST be the exact same architecture as the one you trained.
15
+ # All the hyperparameters (embedding_dim, hidden_dim, etc.) must match.
16
class CodeClassifierRNN(nn.Module):
    """LSTM-based sequence classifier over integer-encoded characters.

    The architecture is embedding -> (stacked, optionally bidirectional) LSTM
    -> dropout -> linear layer producing one logit per class.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        hidden_dim: Hidden size of the LSTM, per direction.
        output_dim: Number of output classes (logits).
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM reads the sequence in both directions.
        dropout: Dropout probability applied between LSTM layers (only when
            ``n_layers > 1``) and before the final linear layer.
        pad_idx: Index of the padding token; its embedding is kept at zero.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, pad_idx):
        super().__init__()
        # Plain Python attribute (neither parameter nor buffer), so it does not
        # add a key to the state dict and saved weights still load unchanged.
        self.bidirectional = bidirectional
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            # nn.LSTM only applies dropout between layers, so it is disabled
            # for a single-layer network.
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # The classifier consumes one final hidden state per direction, so its
        # input width follows the `bidirectional` flag instead of assuming 2.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_dim)

    def forward(self, text):
        """Return class logits for a batch of index sequences.

        Args:
            text: Integer tensor of token indices, shape ``(batch, seq_len)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, output_dim)`` with unnormalized logits.
        """
        embedded = self.embedding(text)
        _, (hidden, _) = self.lstm(embedded)
        if self.bidirectional:
            # The last two entries of h_n are the top layer's forward and
            # backward final states; concatenate them feature-wise.
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            # Unidirectional: only the top layer's final state is relevant.
            hidden = hidden[-1, :, :]
        hidden = self.dropout(hidden)
        return self.fc(hidden)
30
+
31
# --- Step 3: Instantiate the model and load the trained weights ---
# The hyperparameters below must match the training script; otherwise the
# saved tensors will not fit the freshly constructed layers.
PAD_IDX = char_to_int['<PAD>']
VOCAB_SIZE = len(char_to_int)
EMBEDDING_DIM = 128
HIDDEN_DIM = 192  # Must match the final trained model
OUTPUT_DIM = len(int_to_lang)
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5

# Create an instance of the model
model = CodeClassifierRNN(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT, PAD_IDX)

# Load the saved state dictionary.
# map_location='cpu' because the Space runs on a CPU.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, and avoids executing arbitrary pickled code.
# NOTE(review): requires a PyTorch version that supports `weights_only` (1.13+).
state_dict = torch.load('polyglot_classifier.pt', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Evaluation mode: disables dropout for inference
49
+
50
+ # --- Step 4: Create the prediction function ---
51
def classify_code(code_snippet):
    """Predict the most likely programming languages for a code snippet.

    Args:
        code_snippet: Source text to classify. A falsy value (``None`` or an
            empty string) short-circuits to an empty result.

    Returns:
        A dict mapping language name to probability for the top predictions
        (at most 5, fewer if the model has fewer classes). This is the
        mapping passed to the ``gr.Label`` output component.
    """
    if not code_snippet:
        return {}

    # 1. Convert snippet to a tensor of indices; unknown characters share the
    #    '<UNK>' index, which is looked up once rather than per character.
    unk_idx = char_to_int['<UNK>']
    indexed = [char_to_int.get(c, unk_idx) for c in code_snippet]
    tensor = torch.LongTensor(indexed).unsqueeze(0)  # Add batch dimension

    # 2. Run the model without tracking gradients
    with torch.no_grad():
        prediction = model(tensor)

    # 3. Turn logits into probabilities for the single batch element
    probabilities = torch.softmax(prediction, dim=1)[0]

    # 4. Take the top predictions; clamp k so topk does not raise when the
    #    model has fewer than 5 classes.
    k = min(5, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # 5. Format for Gradio output; int_to_lang is keyed by stringified index
    return {
        int_to_lang[str(idx)]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
73
+
74
# --- Step 5: Create and launch the Gradio Interface ---
# The interface wires classify_code to a code editor input and a label output
# that displays up to five classes with their confidences.
iface = gr.Interface(
    fn=classify_code,
    inputs=gr.Code(language=None, label="Code Snippet"),
    outputs=gr.Label(num_top_classes=5, label="Predicted Language"),
    title="Polyglot Code Classifier",
    description="Enter a code snippet to see which programming language the AI thinks it is. This model was trained from scratch on a custom dataset.",
    examples=[
        ["def hello_world():\n print('Hello from Python!')"],
        ["function greet() {\n console.log('Hello from JavaScript!');\n}"],
        ["public class Main {\n public static void main(String[] args) {\n System.out.println(\"Hello, Java!\");\n }\n}"],
    ],
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse classify_code) does not launch a web server.
if __name__ == "__main__":
    iface.launch()