File size: 2,070 Bytes
5eb74ff
 
 
 
 
 
 
 
 
 
 
 
ea899f2
 
5eb74ff
 
 
 
 
 
ea899f2
5eb74ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Hugging Face Hub checkpoint that backs the code-fixing assistant.
model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

def load_model():
    """Load the causal language model and tokenizer for ``model_name``.

    The model is loaded in float16 with ``device_map="auto"``, which lets
    Accelerate decide device placement.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    pretrained_kwargs = dict(
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,  # This requires Accelerate
    )
    lm = AutoModelForCausalLM.from_pretrained(model_name, **pretrained_kwargs)
    tok = AutoTokenizer.from_pretrained(model_name)
    return lm, tok

# Loaded once at import time; fix_code() reuses these module-level objects.
model, tokenizer = load_model()

@spaces.GPU(duration=60)
def fix_code(input_code):
    """Ask the chat model to find and fix errors in ``input_code``.

    Args:
        input_code: Source text from the input editor. ``None``, empty or
            whitespace-only input is treated as "nothing to fix".

    Returns:
        The model's reply as a string (the system prompt asks it to identify
        the errors and provide a corrected version), or ``""`` when there is
        no input.
    """
    # Nothing to fix: don't run a full generation on an empty prompt.
    if not input_code or not input_code.strip():
        return ""

    messages = [
        {"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
        {"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
    ]

    # Render the conversation with the model's chat template as plain text,
    # ending with the assistant turn header so the model starts its reply.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        # temperature/top_p are only honored when sampling is enabled; state
        # it explicitly instead of relying on the checkpoint's generation config.
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )

    # generate() returns prompt + continuation; keep only the new tokens of
    # the single sequence in the batch.
    prompt_len = model_inputs.input_ids.shape[1]
    response = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

    return response

# Gradio UI: a code editor for the input, the model's reply as output.
iface = gr.Interface(
    fn=fix_code,
    inputs=gr.Code(
        label="Input Code",
        language="python",
        lines=10
    ),
    # fix_code returns the model's free-form reply: an explanation of the
    # errors plus the corrected version, presumably with fenced code blocks.
    # Render it as Markdown; in a Python code editor the prose and fence
    # markers would be shown (and highlighted) as if they were code.
    outputs=gr.Markdown(label="Corrected Code"),
    title="Code Correction Tool",
    description="Enter your code with errors, and the AI will attempt to fix it.",
    # One inner list per example, one value per input component.
    examples=[
        ["def fibonacci(n):\n    if n = 0:\n        return 0\n    elif n == 1\n        return 1\n    else:\n        return fibonacci(n-1) + fibonacci(n-2)"],
        ["for i in range(10)\n    print(i"]
    ]
)

if __name__ == "__main__":
    iface.launch()