import gradio as gr from transformers import TrOCRProcessor, VisionEncoderDecoderModel import torch from PIL import Image import numpy as np import cv2 # OCR 모델 및 프로세서 초기화 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten') model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten') # 정답 및 해설 데이터베이스 (20문제) answer_key = { "1": {"answer": "민주주의", "explanation": "민주주의는 국민이 주인이 되어 나라의 중요한 일을 결정하는 제도입니다."}, "2": {"answer": "삼권분립", "explanation": "삼권분립은 입법부, 행정부, 사법부로 권력을 나누어 서로 견제와 균형을 이루게 하는 제도입니다."}, "3": {"answer": "지방자치제도", "explanation": "지방자치제도는 지역의 일을 그 지역 주민들이 직접 결정하고 처리하는 제도입니다."}, "4": {"answer": "헌법", "explanation": "헌법은 국가의 최고 법으로, 국민의 기본권과 정부 조직에 대한 기본 원칙을 담고 있습니다."}, "5": {"answer": "국회", "explanation": "국회는 법률을 만들고 정부를 감시하는 입법부의 역할을 담당합니다."}, # 6~20번까지 문제 추가 (실제 운영 시에는 여기에 추가) } def segment_answers(image): """시험지에서 답안 영역을 분할하는 함수""" if isinstance(image, np.ndarray): pil_image = Image.fromarray(image) else: return None # 이미지를 그레이스케일로 변환 gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # 이미지 이진화 _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV) # 윤곽선 찾기 contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # 답안 영역 추출 answer_regions = [] for contour in contours: x, y, w, h = cv2.boundingRect(contour) if w > 50 and h > 20: # 최소 크기 필터링 region = image[y:y+h, x:x+w] answer_regions.append({ 'image': region, 'position': (y, x) # y좌표로 정렬하기 위해 (y,x) 순서로 저장 }) # y좌표를 기준으로 정렬 (위에서 아래로) answer_regions.sort(key=lambda x: x['position'][0]) return [region['image'] for region in answer_regions] def recognize_text(image): """손글씨 인식 함수""" if isinstance(image, np.ndarray): image = Image.fromarray(image) pixel_values = processor(image, return_tensors="pt").pixel_values with torch.no_grad(): generated_ids = model.generate(pixel_values) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return generated_text def grade_answer(question_number, student_answer): """답안 채점 함수""" question_number = str(question_number) if question_number not in answer_key: return None correct_answer = answer_key[question_number]["answer"] explanation = answer_key[question_number]["explanation"] # 답안 비교 (띄어쓰기, 대소문자 무시) is_correct = student_answer.replace(" ", "").lower() == correct_answer.replace(" ", "").lower() return { "문제번호": question_number, "학생답안": student_answer, "정답여부": "O" if is_correct else "X", "정답": correct_answer, "해설": explanation } def process_full_exam(image): """전체 시험지 처리 함수""" if image is None or not isinstance(image, np.ndarray): return "시험지 이미지를 업로드해주세요." try: # 답안 영역 분할 answer_regions = segment_answers(image) if not answer_regions: return "답안 영역을 찾을 수 없습니다. 이미지를 확인해주세요." # 채점 결과 저장 results = [] total_correct = 0 # 각 답안 영역 처리 for idx, region in enumerate(answer_regions, 1): if idx > len(answer_key): # 정의된 문제 수를 초과하면 중단 break # 텍스트 인식 recognized_text = recognize_text(region) # 채점 result = grade_answer(idx, recognized_text) if result: results.append(result) if result["정답여부"] == "O": total_correct += 1 # 결과 포맷팅 score = (total_correct / len(results)) * 100 output = f"총점: {score:.1f}점 (20문제 중 {total_correct}개 정답)\n\n" output += "=== 상세 채점 결과 ===\n\n" for result in results: output += f""" [{result['문제번호']}번] {'✓' if result['정답여부']=='O' else '✗'} 학생답안: {result['학생답안']} 정답: {result['정답']} 해설: {result['해설']} """ return output except Exception as e: return f"처리 중 오류가 발생했습니다: {str(e)}" # Gradio 인터페이스 생성 iface = gr.Interface( fn=process_full_exam, inputs=gr.Image(label="시험지 이미지를 업로드하세요", type="numpy"), outputs=gr.Textbox(label="채점 결과"), title="초등학교 사회 시험지 채점 프로그램", description=""" 전체 시험지를 한 번에 채점하는 프로그램입니다. 시험지의 답안이 잘 보이도록 깨끗하게 스캔하거나 촬영해주세요. """, examples=[], # 예시 이미지를 추가할 수 있습니다 ) if __name__ == "__main__": iface.launch()