Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,229 +1,177 @@
--- app.py (before)
 # app.py
-# ZeroGPU-…
-# …

 import os
-…

-import spaces  # MUST be imported before anything that may touch CUDA
 import gradio as gr
 from PIL import Image

-# …
 OCR_REPO = "rednote-hilab/dots.ocr"
 CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"

 SYSTEM_MSG = (
     "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
     "Keep wording and punctuation; change only orthography."
 )

-def _get_ocr_prompt():
-    …
-    )
-    except Exception:
-        _state["ocr_prompt"] = (
-            "Extract the original text from this image as plain text. "
-            "Keep the reading order. Do not translate. Do not add extra formatting."
-        )
-    return _state["ocr_prompt"]
-
-
-def _pick_attn_impl():
-    """
-    Decide attention backend for OCR model.
-    If flash_attn can be imported successfully (and matches Torch/CUDA), use flash_attention_2.
-    Otherwise fall back to eager (most stable on Spaces/ZeroGPU).
-    """
-    try:
-        import importlib
-        _ = importlib.import_module("flash_attn")  # may raise
-        return "flash_attention_2"
-    except Exception:
-        return "eager"
-
-
-def _ensure_models_on_gpu():
-    """
-    Create/load models ONLY when on the GPU worker.
-    No torch/transformers imports at module scope.
-    """
-    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
-
-    # OCR model/processor
-    if _state["ocr_model"] is None or _state["ocr_processor"] is None:
-        _state["ocr_model"] = AutoModelForCausalLM.from_pretrained(
-            OCR_REPO,
-            trust_remote_code=True,
-            attn_implementation=_pick_attn_impl(),  # "eager" if flash-attn unavailable
-            device_map="auto",
-            torch_dtype="auto",
-        )
-        _state["ocr_processor"] = AutoProcessor.from_pretrained(
-            OCR_REPO, trust_remote_code=True
-        )
-
-    # Conversion model/tokenizer (pre-reform -> modern Russian)
-    if _state["conv_model"] is None or _state["conv_tok"] is None:
-        _state["conv_tok"] = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
-        _state["conv_model"] = AutoModelForCausalLM.from_pretrained(
-            CONVERT_REPO,
-            device_map="auto",
-            torch_dtype="auto",
-        )

-…
-    ocr_prompt = _get_ocr_prompt()

     messages = [
         {
             "role": "user",
             "content": [
                 {"type": "image", "image": pil_image},
-                {"type": "text", "text": ocr_prompt},
             ],
         }
     ]
-    text = ocr_processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
     image_inputs, video_inputs = process_vision_info(messages)
-    inputs = ocr_processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
         return_tensors="pt",
-    )
-
-    # Move to model device
-    dev = next(ocr_model.parameters()).device
-    inputs = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in inputs.items()}

-    # Generate
     with torch.no_grad():
-        …


-def …
     """Use your merged model to convert pre-reform Russian -> modern Russian."""
-    import torch
-
-    conv_model = _state["conv_model"]
-    conv_tok = _state["conv_tok"]
-
     messages = [
         {"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": pre_reform_text},
     ]
-    prompt = …
-    )
-
-    inputs = conv_tok([prompt], return_tensors="pt")
-    dev = next(conv_model.parameters()).device
-    inputs = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in inputs.items()}

     with torch.no_grad():
-        gen = …
             **inputs,
             max_new_tokens=1024,
-            do_sample=False,
             temperature=0.0,
             repetition_penalty=1.05,
         )

     gen_only = gen[0][inputs["input_ids"].shape[1]:]
-    return …


-@spaces.GPU()  # …
-def transcribe_and_convert(…
-    …
-    # 3) Markdown code block for easy copy
-    md = f"```text\n{modern_text}\n```"
-
-    return pil_image, ocr_text, modern_text, md


-# ---------------- UI ----------------
 with gr.Blocks(title="Pre-reform → Modern Russian (OCR + Conversion)") as demo:
     gr.Markdown(
         "## Pre-reform → Modern Russian (OCR + Conversion)\n"
-        "…
-        "2) Click **Transcribe & Convert** — the app will OCR via `rednote-hilab/dots.ocr` and convert to modern spelling."
     )

     with gr.Row():
         with gr.Column(scale=1):
-            …
             run_btn = gr.Button("Transcribe & Convert", variant="primary")
-            gr.Markdown("Tip: higher-resolution images OCR better. For PDFs, export a page as an image.")

         with gr.Column(scale=2):
             with gr.Row():
-                …
                 ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
             modern_box = gr.Textbox(label="Modern Russian", lines=14)
-            …

     run_btn.click(
         transcribe_and_convert,
-        inputs=[…
-        outputs=[…
         api_name="transcribe_convert",
     )

-demo.queue(max_size=…
+++ app.py (after)
 # app.py
+# Gradio + ZeroGPU: OCR pre-reform RU with dots.ocr -> convert to modern RU with your model.
+# Same technique as the working Space you showed:
+# - import `spaces` first
+# - snapshot the OCR repo locally
+# - load models at module scope (after spaces import)
+# - use @spaces.GPU() for the heavy call
+
+import spaces  # must be first so ZeroGPU patches CUDA init correctly

 import os
+import traceback
+from io import BytesIO
+from typing import Tuple

 import gradio as gr
+import requests
+import torch
+from huggingface_hub import snapshot_download
 from PIL import Image
+from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

+# ---------- Config ----------
 OCR_REPO = "rednote-hilab/dots.ocr"
+OCR_LOCAL_DIR = "./models/dots-ocr-local"  # local snapshot dir
 CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"

 SYSTEM_MSG = (
     "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
     "Keep wording and punctuation; change only orthography."
 )
+OCR_PROMPT = (
+    "Extract the original text from this image as plain text. "
+    "Keep the reading order. Do not translate. Do not add extra formatting."
+)

+# ---------- Utils ----------
+def fetch_image(image_input) -> Image.Image:
+    """Accept Gradio image (PIL) or URL/path string and return a PIL RGB image."""
+    if isinstance(image_input, Image.Image):
+        return image_input.convert("RGB")
+    if isinstance(image_input, str):
+        if image_input.startswith(("http://", "https://")):
+            resp = requests.get(image_input, timeout=30)
+            resp.raise_for_status()
+            return Image.open(BytesIO(resp.content)).convert("RGB")
+        return Image.open(image_input).convert("RGB")
+    raise ValueError(f"Unsupported image input: {type(image_input)}")
+
+# ---------- Snapshot + load models at module scope (after spaces import) ----------
+# Snapshot OCR model locally to avoid dynamic code churn and speed up cold starts.
+snapshot_download(
+    repo_id=OCR_REPO,
+    local_dir=OCR_LOCAL_DIR,
+    local_dir_use_symlinks=False,
+)

+# Load OCR (tries flash-attn 2 path; if it's mismatched at runtime, you can switch to "eager")
+_ocr_model = AutoModelForCausalLM.from_pretrained(
+    OCR_LOCAL_DIR,
+    attn_implementation="flash_attention_2",  # matches the working Space technique
+    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
+    device_map="auto",
+    trust_remote_code=True,
+)
+_ocr_processor = AutoProcessor.from_pretrained(OCR_LOCAL_DIR, trust_remote_code=True)
+
+# Load conversion model (pre-reform -> modern Russian)
+_convert_tokenizer = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
+_convert_model = AutoModelForCausalLM.from_pretrained(
+    CONVERT_REPO,
+    device_map="auto",
+    torch_dtype="auto",
+)

+# Device (safe after spaces import)
+_device = "cuda" if torch.cuda.is_available() else "cpu"
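
Note on the pinned backend above: `attn_implementation="flash_attention_2"` needs a flash-attn wheel that matches the runtime's Torch/CUDA build, and a mismatch is a common cause of a Space-wide runtime error at import time. The revision this commit removes guarded against exactly that; a minimal sketch of the same fallback (the function name here is ours, not part of this commit):

def pick_attn_impl() -> str:
    # Probe for a usable flash-attn build; fall back to "eager",
    # which is the most stable backend on Spaces/ZeroGPU.
    try:
        import flash_attn  # noqa: F401 -- raises if missing or ABI-mismatched
        return "flash_attention_2"
    except Exception:
        return "eager"

Passing `attn_implementation=pick_attn_impl()` to `from_pretrained` keeps the module-scope load alive either way.
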

+# ---------- Core pipeline ----------
+def run_ocr(pil_image: Image.Image) -> str:
+    """OCR using dots.ocr; returns plain text."""
+    # Build messages for OCR model
     messages = [
         {
             "role": "user",
             "content": [
                 {"type": "image", "image": pil_image},
+                {"type": "text", "text": OCR_PROMPT},
             ],
         }
     ]
+    # Prepare inputs
+    text = _ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
+    inputs = _ocr_processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
         return_tensors="pt",
+    ).to(_device)

     with torch.no_grad():
+        generated_ids = _ocr_model.generate(
+            **inputs,
+            max_new_tokens=4096,
+            do_sample=False,
+            temperature=0.0,
+        )
+    # Trim prompt
+    trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
+    out_text = _ocr_processor.batch_decode(
+        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+    return (out_text or "").strip()


+def convert_prereform_to_modern(pre_reform_text: str) -> str:
     """Use your merged model to convert pre-reform Russian -> modern Russian."""
     messages = [
         {"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": pre_reform_text},
     ]
+    prompt = _convert_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = _convert_tokenizer([prompt], return_tensors="pt").to(_convert_model.device)

     with torch.no_grad():
+        gen = _convert_model.generate(
             **inputs,
             max_new_tokens=1024,
+            do_sample=False,
             temperature=0.0,
             repetition_penalty=1.05,
         )
     gen_only = gen[0][inputs["input_ids"].shape[1]:]
+    return _convert_tokenizer.decode(gen_only, skip_special_tokens=True).strip()
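
A quick way to smoke-test the conversion step on its own (illustrative only: the sample string is ours, and the exact output depends on the merged model):

sample = "Въ началѣ было слово."            # pre-reform spelling
print(convert_prereform_to_modern(sample))  # expected shape: "В начале было слово."
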

+@spaces.GPU()  # heavy work happens on ZeroGPU worker
+def transcribe_and_convert(image_in) -> Tuple[Image.Image, str, str, str]:
+    try:
+        pil = fetch_image(image_in)
+        ocr_text = run_ocr(pil)
+        modern_text = convert_prereform_to_modern(ocr_text)
+        md_block = f"```text\n{modern_text}\n```"
+        return pil, ocr_text, modern_text, md_block
+    except Exception as e:
+        traceback.print_exc()
+        err = f"Error: {e}"
+        return None, "", "", err
+
+# ---------- UI ----------
 with gr.Blocks(title="Pre-reform → Modern Russian (OCR + Conversion)") as demo:
     gr.Markdown(
         "## Pre-reform → Modern Russian (OCR + Conversion)\n"
+        "Upload an image containing pre-1918 Russian text → OCR via **dots.ocr** → convert to modern Russian."
     )
     with gr.Row():
         with gr.Column(scale=1):
+            img_in = gr.Image(type="pil", label="Upload image (pre-reform Russian)")
             run_btn = gr.Button("Transcribe & Convert", variant="primary")
         with gr.Column(scale=2):
             with gr.Row():
+                img_preview = gr.Image(label="Preview", interactive=False)
                 ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
             modern_box = gr.Textbox(label="Modern Russian", lines=14)
+            md_box = gr.Markdown(label="Modern Russian (markdown code block)")

     run_btn.click(
         transcribe_and_convert,
+        inputs=[img_in],
+        outputs=[img_preview, ocr_box, modern_box, md_box],
         api_name="transcribe_convert",
     )

+demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)
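
Because the click handler is registered with `api_name="transcribe_convert"`, the Space can also be driven programmatically. A minimal client sketch, assuming the `gradio_client` package; the Space id shown is a placeholder, not taken from this commit:

from gradio_client import Client, handle_file

client = Client("ZennyKenny/your-space-id")  # placeholder Space id
preview, ocr_text, modern_text, md = client.predict(
    handle_file("scan.png"),  # local image of a pre-reform page
    api_name="/transcribe_convert",
)
print(modern_text)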