ZennyKenny commited on
Commit
818e485
·
verified ·
1 Parent(s): 129f697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -161
app.py CHANGED
@@ -1,229 +1,177 @@
1
  # app.py
2
- # ZeroGPU-friendly Gradio app:
3
- # 1) Upload image with pre-reform Russian.
4
- # 2) OCR via rednote-hilab/dots.ocr.
5
- # 3) Convert to modern Russian via your HF model.
6
- # Notes:
7
- # - Import `spaces` FIRST and avoid any CUDA/tensor ops at module import.
8
- # - All torch/transformers/qwen_vl_utils imports happen INSIDE the @spaces.GPU() path.
9
- # - Attn impl defaults to "eager" (no flash-attn required). If flash_attn is present & compatible, we'll use it.
10
 
11
  import os
12
- os.environ.setdefault("PYTORCH_NVML_BASED_CUDA_CHECK", "0") # avoid NVML probe before ZeroGPU init
 
 
13
 
14
- import spaces # MUST be imported before anything that may touch CUDA
15
  import gradio as gr
 
 
 
16
  from PIL import Image
 
 
17
 
18
- # --- Repos & constants ---
19
  OCR_REPO = "rednote-hilab/dots.ocr"
 
20
  CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"
21
 
22
  SYSTEM_MSG = (
23
  "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
24
  "Keep wording and punctuation; change only orthography."
25
  )
 
 
 
 
26
 
27
- # --- Lazy state (populated on first GPU call) ---
28
- _state = {
29
- "ocr_model": None,
30
- "ocr_processor": None,
31
- "conv_model": None,
32
- "conv_tok": None,
33
- "ocr_prompt": None,
34
- }
35
-
36
-
37
- def _get_ocr_prompt():
38
- """Fetch OCR text-extraction prompt from dots.ocr utils if available, else fallback."""
39
- if _state["ocr_prompt"] is not None:
40
- return _state["ocr_prompt"]
41
- try:
42
- # Import lazily to avoid early CUDA init
43
- from dots_ocr.utils import dict_promptmode_to_prompt # type: ignore
44
- _state["ocr_prompt"] = dict_promptmode_to_prompt().get("prompt_ocr") or (
45
- "Extract the original text from this image as plain text. "
46
- "Keep the reading order. Do not translate. Do not add extra formatting."
47
- )
48
- except Exception:
49
- _state["ocr_prompt"] = (
50
- "Extract the original text from this image as plain text. "
51
- "Keep the reading order. Do not translate. Do not add extra formatting."
52
- )
53
- return _state["ocr_prompt"]
54
-
55
-
56
- def _pick_attn_impl():
57
- """
58
- Decide attention backend for OCR model.
59
- If flash_attn can be imported successfully (and matches Torch/CUDA), use flash_attention_2.
60
- Otherwise fall back to eager (most stable on Spaces/ZeroGPU).
61
- """
62
- try:
63
- import importlib
64
- _ = importlib.import_module("flash_attn") # may raise
65
- return "flash_attention_2"
66
- except Exception:
67
- return "eager"
68
-
69
-
70
- def _ensure_models_on_gpu():
71
- """
72
- Create/load models ONLY when on the GPU worker.
73
- No torch/transformers imports at module scope.
74
- """
75
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
76
-
77
- # OCR model/processor
78
- if _state["ocr_model"] is None or _state["ocr_processor"] is None:
79
- _state["ocr_model"] = AutoModelForCausalLM.from_pretrained(
80
- OCR_REPO,
81
- trust_remote_code=True,
82
- attn_implementation=_pick_attn_impl(), # "eager" if flash-attn unavailable
83
- device_map="auto",
84
- torch_dtype="auto",
85
- )
86
- _state["ocr_processor"] = AutoProcessor.from_pretrained(
87
- OCR_REPO, trust_remote_code=True
88
- )
89
-
90
- # Conversion model/tokenizer (pre-reform -> modern Russian)
91
- if _state["conv_model"] is None or _state["conv_tok"] is None:
92
- _state["conv_tok"] = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
93
- _state["conv_model"] = AutoModelForCausalLM.from_pretrained(
94
- CONVERT_REPO,
95
- device_map="auto",
96
- torch_dtype="auto",
97
- )
98
-
99
 
100
- def _run_ocr_on_gpu(pil_image: Image.Image) -> str:
101
- """Run dots.ocr on the given image and return raw OCR text."""
102
- # Heavy imports inside GPU context
103
- import torch
104
- from qwen_vl_utils import process_vision_info
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- ocr_model = _state["ocr_model"]
107
- ocr_processor = _state["ocr_processor"]
108
- ocr_prompt = _get_ocr_prompt()
109
 
110
- # Build chat-style message with image + text
 
 
 
111
  messages = [
112
  {
113
  "role": "user",
114
  "content": [
115
  {"type": "image", "image": pil_image},
116
- {"type": "text", "text": ocr_prompt},
117
  ],
118
  }
119
  ]
120
-
121
- # Apply the processor's chat template and package inputs
122
- text = ocr_processor.apply_chat_template(
123
- messages, tokenize=False, add_generation_prompt=True
124
- )
125
  image_inputs, video_inputs = process_vision_info(messages)
126
-
127
- inputs = ocr_processor(
128
  text=[text],
129
  images=image_inputs,
130
  videos=video_inputs,
131
  padding=True,
132
  return_tensors="pt",
133
- )
134
-
135
- # Move to model device
136
- dev = next(ocr_model.parameters()).device
137
- inputs = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in inputs.items()}
138
 
139
- # Generate
140
  with torch.no_grad():
141
- gen_ids = ocr_model.generate(**inputs, max_new_tokens=2048)
142
- prompt_len = inputs["input_ids"].shape[1]
143
- out_ids = gen_ids[0][prompt_len:]
144
- text_out = ocr_processor.decode(out_ids, skip_special_tokens=True).strip()
145
-
146
- return text_out
 
 
 
 
 
 
147
 
148
 
149
- def _convert_on_gpu(pre_reform_text: str) -> str:
150
  """Use your merged model to convert pre-reform Russian -> modern Russian."""
151
- import torch
152
-
153
- conv_model = _state["conv_model"]
154
- conv_tok = _state["conv_tok"]
155
-
156
  messages = [
157
  {"role": "system", "content": SYSTEM_MSG},
158
  {"role": "user", "content": pre_reform_text},
159
  ]
160
- prompt = conv_tok.apply_chat_template(
161
- messages, tokenize=False, add_generation_prompt=True
162
- )
163
-
164
- inputs = conv_tok([prompt], return_tensors="pt")
165
- dev = next(conv_model.parameters()).device
166
- inputs = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in inputs.items()}
167
 
168
  with torch.no_grad():
169
- gen = conv_model.generate(
170
  **inputs,
171
  max_new_tokens=1024,
172
- do_sample=False, # deterministic for orthography conversion
173
  temperature=0.0,
174
  repetition_penalty=1.05,
175
  )
176
-
177
  gen_only = gen[0][inputs["input_ids"].shape[1]:]
178
- return conv_tok.decode(gen_only, skip_special_tokens=True).strip()
179
 
180
 
181
- @spaces.GPU() # ZeroGPU entrypoint: all CUDA must happen inside here (or helpers it calls)
182
- def transcribe_and_convert(pil_image: Image.Image):
183
- if pil_image is None:
184
- return None, "", "", "Please upload an image."
185
-
186
- # Lazily load models on the GPU worker
187
- _ensure_models_on_gpu()
188
-
189
- # 1) OCR
190
- ocr_text = _run_ocr_on_gpu(pil_image)
191
-
192
- # 2) Convert pre-reform -> modern Russian
193
- modern_text = _convert_on_gpu(ocr_text)
194
-
195
- # 3) Markdown code block for easy copy
196
- md = f"```text\n{modern_text}\n```"
197
-
198
- return pil_image, ocr_text, modern_text, md
199
-
200
-
201
- # ---------------- UI ----------------
202
  with gr.Blocks(title="Pre-reform → Modern Russian (OCR + Conversion)") as demo:
203
  gr.Markdown(
204
  "## Pre-reform → Modern Russian (OCR + Conversion)\n"
205
- "1) Upload an image containing pre-1918 Russian text.\n"
206
- "2) Click **Transcribe & Convert** — the app will OCR via `rednote-hilab/dots.ocr` and convert to modern spelling."
207
  )
208
-
209
  with gr.Row():
210
  with gr.Column(scale=1):
211
- image_in = gr.Image(type="pil", label="Upload image")
212
  run_btn = gr.Button("Transcribe & Convert", variant="primary")
213
- gr.Markdown("Tip: higher-resolution images OCR better. For PDFs, export a page as an image.")
214
-
215
  with gr.Column(scale=2):
216
  with gr.Row():
217
- image_preview = gr.Image(label="Preview", interactive=False)
218
  ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
219
  modern_box = gr.Textbox(label="Modern Russian", lines=14)
220
- md_block = gr.Markdown(label="Modern Russian (markdown code block)")
221
 
222
  run_btn.click(
223
  transcribe_and_convert,
224
- inputs=[image_in],
225
- outputs=[image_preview, ocr_box, modern_box, md_block],
226
  api_name="transcribe_convert",
227
  )
228
 
229
- demo.queue(max_size=16).launch()
 
1
  # app.py
2
+ # Gradio + ZeroGPU: OCR pre-reform RU with dots.ocr -> convert to modern RU with your model.
3
+ # Same technique as the working Space you showed:
4
+ # - import `spaces` first
5
+ # - snapshot the OCR repo locally
6
+ # - load models at module scope (after spaces import)
7
+ # - use @spaces.GPU() for the heavy call
8
+
9
+ import spaces # must be first so ZeroGPU patches CUDA init correctly
10
 
11
  import os
12
+ import traceback
13
+ from io import BytesIO
14
+ from typing import Tuple
15
 
 
16
  import gradio as gr
17
+ import requests
18
+ import torch
19
+ from huggingface_hub import snapshot_download
20
  from PIL import Image
21
+ from qwen_vl_utils import process_vision_info
22
+ from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
23
 
24
+ # ---------- Config ----------
25
  OCR_REPO = "rednote-hilab/dots.ocr"
26
+ OCR_LOCAL_DIR = "./models/dots-ocr-local" # local snapshot dir
27
  CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"
28
 
29
  SYSTEM_MSG = (
30
  "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
31
  "Keep wording and punctuation; change only orthography."
32
  )
33
+ OCR_PROMPT = (
34
+ "Extract the original text from this image as plain text. "
35
+ "Keep the reading order. Do not translate. Do not add extra formatting."
36
+ )
37
 
38
+ # ---------- Utils ----------
39
+ def fetch_image(image_input) -> Image.Image:
40
+ """Accept Gradio image (PIL) or URL/path string and return a PIL RGB image."""
41
+ if isinstance(image_input, Image.Image):
42
+ return image_input.convert("RGB")
43
+ if isinstance(image_input, str):
44
+ if image_input.startswith(("http://", "https://")):
45
+ resp = requests.get(image_input, timeout=30)
46
+ resp.raise_for_status()
47
+ return Image.open(BytesIO(resp.content)).convert("RGB")
48
+ return Image.open(image_input).convert("RGB")
49
+ raise ValueError(f"Unsupported image input: {type(image_input)}")
50
+
51
+ # ---------- Snapshot + load models at module scope (after spaces import) ----------
52
+ # Snapshot OCR model locally to avoid dynamic code churn and speed up cold starts.
53
+ snapshot_download(
54
+ repo_id=OCR_REPO,
55
+ local_dir=OCR_LOCAL_DIR,
56
+ local_dir_use_symlinks=False,
57
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Load OCR (tries flash-attn 2 path; if it's mismatched at runtime, you can switch to "eager")
60
+ _ocr_model = AutoModelForCausalLM.from_pretrained(
61
+ OCR_LOCAL_DIR,
62
+ attn_implementation="flash_attention_2", # matches the working Space technique
63
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
64
+ device_map="auto",
65
+ trust_remote_code=True,
66
+ )
67
+ _ocr_processor = AutoProcessor.from_pretrained(OCR_LOCAL_DIR, trust_remote_code=True)
68
+
69
+ # Load conversion model (pre-reform -> modern Russian)
70
+ _convert_tokenizer = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
71
+ _convert_model = AutoModelForCausalLM.from_pretrained(
72
+ CONVERT_REPO,
73
+ device_map="auto",
74
+ torch_dtype="auto",
75
+ )
76
 
77
+ # Device (safe after spaces import)
78
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
 
79
 
80
+ # ---------- Core pipeline ----------
81
+ def run_ocr(pil_image: Image.Image) -> str:
82
+ """OCR using dots.ocr; returns plain text."""
83
+ # Build messages for OCR model
84
  messages = [
85
  {
86
  "role": "user",
87
  "content": [
88
  {"type": "image", "image": pil_image},
89
+ {"type": "text", "text": OCR_PROMPT},
90
  ],
91
  }
92
  ]
93
+ # Prepare inputs
94
+ text = _ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
95
  image_inputs, video_inputs = process_vision_info(messages)
96
+ inputs = _ocr_processor(
 
97
  text=[text],
98
  images=image_inputs,
99
  videos=video_inputs,
100
  padding=True,
101
  return_tensors="pt",
102
+ ).to(_device)
 
 
 
 
103
 
 
104
  with torch.no_grad():
105
+ generated_ids = _ocr_model.generate(
106
+ **inputs,
107
+ max_new_tokens=4096,
108
+ do_sample=False,
109
+ temperature=0.0,
110
+ )
111
+ # Trim prompt
112
+ trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
113
+ out_text = _ocr_processor.batch_decode(
114
+ trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
115
+ )[0]
116
+ return (out_text or "").strip()
117
 
118
 
119
+ def convert_prereform_to_modern(pre_reform_text: str) -> str:
120
  """Use your merged model to convert pre-reform Russian -> modern Russian."""
 
 
 
 
 
121
  messages = [
122
  {"role": "system", "content": SYSTEM_MSG},
123
  {"role": "user", "content": pre_reform_text},
124
  ]
125
+ prompt = _convert_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
126
+ inputs = _convert_tokenizer([prompt], return_tensors="pt").to(_convert_model.device)
 
 
 
 
 
127
 
128
  with torch.no_grad():
129
+ gen = _convert_model.generate(
130
  **inputs,
131
  max_new_tokens=1024,
132
+ do_sample=False,
133
  temperature=0.0,
134
  repetition_penalty=1.05,
135
  )
 
136
  gen_only = gen[0][inputs["input_ids"].shape[1]:]
137
+ return _convert_tokenizer.decode(gen_only, skip_special_tokens=True).strip()
138
 
139
 
140
+ @spaces.GPU() # heavy work happens on ZeroGPU worker
141
+ def transcribe_and_convert(image_in) -> Tuple[Image.Image, str, str, str]:
142
+ try:
143
+ pil = fetch_image(image_in)
144
+ ocr_text = run_ocr(pil)
145
+ modern_text = convert_prereform_to_modern(ocr_text)
146
+ md_block = f"```text\n{modern_text}\n```"
147
+ return pil, ocr_text, modern_text, md_block
148
+ except Exception as e:
149
+ traceback.print_exc()
150
+ err = f"Error: {e}"
151
+ return None, "", "", err
152
+
153
+ # ---------- UI ----------
 
 
 
 
 
 
 
154
  with gr.Blocks(title="Pre-reform → Modern Russian (OCR + Conversion)") as demo:
155
  gr.Markdown(
156
  "## Pre-reform → Modern Russian (OCR + Conversion)\n"
157
+ "Upload an image containing pre-1918 Russian text → OCR via **dots.ocr** → convert to modern Russian."
 
158
  )
 
159
  with gr.Row():
160
  with gr.Column(scale=1):
161
+ img_in = gr.Image(type="pil", label="Upload image (pre-reform Russian)")
162
  run_btn = gr.Button("Transcribe & Convert", variant="primary")
 
 
163
  with gr.Column(scale=2):
164
  with gr.Row():
165
+ img_preview = gr.Image(label="Preview", interactive=False)
166
  ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
167
  modern_box = gr.Textbox(label="Modern Russian", lines=14)
168
+ md_box = gr.Markdown(label="Modern Russian (markdown code block)")
169
 
170
  run_btn.click(
171
  transcribe_and_convert,
172
+ inputs=[img_in],
173
+ outputs=[img_preview, ocr_box, modern_box, md_box],
174
  api_name="transcribe_convert",
175
  )
176
 
177
+ demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)