import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
from PIL import Image
import numpy as np
import cv2
# OCR ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ์ดˆ๊ธฐํ™”
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
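# NOTE: microsoft/trocr-base-handwritten is fine-tuned on the IAM dataset of
# English handwriting, so recognizing Korean answers reliably would need a
# checkpoint fine-tuned on Korean handwriting instead.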
# Answer key and explanation database (20 questions; 5 shown here)
answer_key = {
    "1": {"answer": "민주주의", "explanation": "민주주의는 국민이 주인이 되어 나라의 중요한 일을 결정하는 제도입니다."},
    "2": {"answer": "삼권분립", "explanation": "삼권분립은 입법부, 행정부, 사법부로 권력을 나누어 서로 견제와 균형을 이루게 하는 제도입니다."},
    "3": {"answer": "지방자치제도", "explanation": "지방자치제도는 지역의 일을 그 지역 주민들이 직접 결정하고 처리하는 제도입니다."},
    "4": {"answer": "헌법", "explanation": "헌법은 국가의 최고 법으로, 국민의 기본권과 정부 조직에 대한 기본 원칙을 담고 있습니다."},
    "5": {"answer": "국회", "explanation": "국회는 법률을 만들고 정부를 감시하는 입법부의 역할을 담당합니다."},
    # Add questions 6-20 here (required for actual use)
}
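# For deployment, the answer key could be loaded from a file instead of being
# hard-coded. A minimal sketch, assuming a hypothetical answer_key.json kept
# next to this script:
#
#     import json
#     with open("answer_key.json", encoding="utf-8") as f:
#         answer_key = json.load(f)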
def segment_answers(image):
    """Segment candidate answer regions from the exam sheet image."""
    if not isinstance(image, np.ndarray):
        return None
    # Convert the image to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Binarize (inverted, so ink becomes white on a black background)
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
    # Find the external contours of the ink regions
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Extract answer regions, filtering out small specks of noise
    answer_regions = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w > 50 and h > 20:  # minimum-size filter
            region = image[y:y + h, x:x + w]
            answer_regions.append({
                'image': region,
                'position': (y, x),  # stored as (y, x) so regions can be sorted by y
            })
    # Sort by y coordinate (top to bottom)
    answer_regions.sort(key=lambda r: r['position'][0])
    return [region['image'] for region in answer_regions]
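# Handwritten characters in one answer often split into several contours. A
# preprocessing sketch (not wired into segment_answers; the (40, 1) kernel
# width is an assumption to tune) that merges them into one region per line:
def merge_characters_into_lines(thresh):
    """Dilate a binarized image horizontally so each answer line becomes one contour."""
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    return cv2.dilate(thresh, kernel, iterations=1)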
def recognize_text(image):
    """Recognize handwritten text in a single answer region."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")  # TrOCR expects three-channel input
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
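# recognize_text runs one generate() call per region. A batched sketch, which
# should be noticeably faster when a sheet yields many regions:
def recognize_text_batch(regions):
    """Recognize handwriting for a list of RGB ndarray regions in one batch."""
    pil_images = [Image.fromarray(r).convert("RGB") for r in regions]
    pixel_values = processor(images=pil_images, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)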
def grade_answer(question_number, student_answer):
    """Grade a single answer against the answer key."""
    question_number = str(question_number)
    if question_number not in answer_key:
        return None
    correct_answer = answer_key[question_number]["answer"]
    explanation = answer_key[question_number]["explanation"]
    # Compare answers, ignoring whitespace and letter case
    is_correct = student_answer.replace(" ", "").lower() == correct_answer.replace(" ", "").lower()
    return {
        "문제번호": question_number,
        "학생답안": student_answer,
        "정답여부": "O" if is_correct else "X",
        "정답": correct_answer,
        "해설": explanation,
    }
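# Exact string equality is brittle against OCR noise. A tolerant comparison
# sketch using difflib from the standard library (the 0.8 threshold is an
# assumption to tune):
def is_close_match(student_answer, correct_answer, threshold=0.8):
    """Return True when the recognized answer is close enough to the key."""
    from difflib import SequenceMatcher
    a = student_answer.replace(" ", "").lower()
    b = correct_answer.replace(" ", "").lower()
    return SequenceMatcher(None, a, b).ratio() >= threshold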
def process_full_exam(image):
    """Process and grade the full exam sheet."""
    if image is None or not isinstance(image, np.ndarray):
        return "시험지 이미지를 업로드해주세요."
    try:
        # Segment the answer regions
        answer_regions = segment_answers(image)
        if not answer_regions:
            return "답안 영역을 찾을 수 없습니다. 이미지를 확인해주세요."
        # Collect grading results
        results = []
        total_correct = 0
        # Process each answer region
        for idx, region in enumerate(answer_regions, 1):
            if idx > len(answer_key):  # stop past the number of defined questions
                break
            # Recognize the handwriting
            recognized_text = recognize_text(region)
            # Grade it
            result = grade_answer(idx, recognized_text)
            if result:
                results.append(result)
                if result["정답여부"] == "O":
                    total_correct += 1
        if not results:  # guard against dividing by zero below
            return "채점할 답안이 없습니다."
        # Format the results
        score = (total_correct / len(results)) * 100
        output = f"총점: {score:.1f}점 ({len(results)}문제 중 {total_correct}개 정답)\n\n"
        output += "=== 상세 채점 결과 ===\n\n"
        for result in results:
            output += f"""
[{result['문제번호']}번] {'✓' if result['정답여부'] == 'O' else '✗'}
학생답안: {result['학생답안']}
정답: {result['정답']}
해설: {result['해설']}
"""
        return output
    except Exception as e:
        return f"처리 중 오류가 발생했습니다: {str(e)}"
# Create the Gradio interface
iface = gr.Interface(
    fn=process_full_exam,
    inputs=gr.Image(label="시험지 이미지를 업로드하세요", type="numpy"),
    outputs=gr.Textbox(label="채점 결과"),
    title="초등학교 사회 시험지 채점 프로그램",
    description="""
    전체 시험지를 한 번에 채점하는 프로그램입니다.
    시험지의 답안이 잘 보이도록 깨끗하게 스캔하거나 촬영해주세요.
    """,
    examples=[],  # example images can be added here
)
if __name__ == "__main__":
    iface.launch()