# (Scraped from a Hugging Face Space page; page status: "Running")
"""
Beautiful Medical NER Demo using OpenMed Models

A Named Entity Recognition demo for medical professionals, built on
specialized OpenMed medical models with displaCy entity visualization.
"""
import html
import logging
import random
import re
import warnings
from typing import Dict, List, Tuple

import gradio as gr
import spacy
from spacy import displacy
from transformers import pipeline
# Suppress warnings for cleaner output.
# NOTE: filterwarnings("ignore") silences *every* Python warning in this
# process, not only those emitted by transformers.
warnings.filterwarnings("ignore")
# Only ERROR and above from the transformers logger reach the console.
logging.getLogger("transformers").setLevel(logging.ERROR)
# Model configurations: dropdown display name -> Hugging Face model id and a
# one-line description shown in the UI. Additional OpenMed models are kept
# commented out; uncomment an entry to offer it in the dropdown.
MODELS = {
    "Oncology Detection": {
        "model_id": "OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M",
        "description": "Specialized in cancer, genetics, and oncology entities",
    },
    # "Pharmaceutical Detection": {
    #     "model_id": "OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M",
    #     "description": "Detects drugs, chemicals, and pharmaceutical entities",
    # },
    # "Disease Detection": {
    #     "model_id": "OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M",
    #     "description": "Identifies diseases, conditions, and pathologies",
    # },
    # "Genome Detection": {
    #     "model_id": "OpenMed/OpenMed-NER-GenomeDetect-ModernClinical-395M",
    #     "description": "Recognizes genes, proteins, and genomic entities",
    # },
}

# Medical text examples for each model (three per model, one per example
# button). Entries for models that are disabled above are harmless: the UI
# only offers the keys of MODELS.
EXAMPLES = {
    "Oncology Detection": [
        "The patient presented with metastatic adenocarcinoma of the lung with mutations in EGFR and KRAS genes. Treatment with erlotinib was initiated, targeting the epidermal growth factor receptor pathway.",
        "Histological examination revealed invasive ductal carcinoma with high-grade nuclear features. The tumor showed positive estrogen receptor and HER2 amplification, indicating potential for targeted therapy.",
        "The oncologist recommended adjuvant chemotherapy with doxorubicin and cyclophosphamide, followed by paclitaxel, to target rapidly dividing cancer cells in the breast tissue.",
    ],
    "Pharmaceutical Detection": [
        "The patient was prescribed metformin 500mg twice daily for diabetes management, along with lisinopril 10mg for hypertension control and atorvastatin 20mg for cholesterol reduction.",
        "Administration of morphine sulfate provided effective pain relief, while ondansetron prevented chemotherapy-induced nausea. The patient also received dexamethasone as an anti-inflammatory agent.",
        "The pharmacokinetic study evaluated the absorption of ibuprofen and its interaction with warfarin, monitoring plasma concentrations and potential bleeding risks.",
    ],
    "Disease Detection": [
        "The patient was diagnosed with type 2 diabetes mellitus, hypertension, and coronary artery disease. Additional findings included diabetic nephropathy and peripheral neuropathy.",
        "Clinical presentation was consistent with acute myocardial infarction complicated by cardiogenic shock. The patient also had a history of chronic obstructive pulmonary disease and atrial fibrillation.",
        "Laboratory results confirmed the diagnosis of rheumatoid arthritis with elevated inflammatory markers. The patient also exhibited symptoms of Sjögren's syndrome and osteoporosis.",
    ],
    "Genome Detection": [
        "Genetic analysis revealed mutations in the BRCA1 and BRCA2 genes, significantly increasing the risk of hereditary breast and ovarian cancer. The p53 tumor suppressor gene also showed alterations.",
        "Expression profiling identified upregulation of MYC oncogene and downregulation of PTEN tumor suppressor. The mTOR signaling pathway showed significant activation in the tumor samples.",
        "Whole genome sequencing detected variants in CFTR gene associated with cystic fibrosis, along with polymorphisms in CYP2D6 affecting drug metabolism and APOE influencing Alzheimer's risk.",
    ],
}
def ner_filtered(text, *, pipe, min_score=0.60, min_length=1, remove_punctuation=True):
    """Run ``pipe`` on ``text`` and drop low-quality entity predictions.

    Args:
        text: Input text, passed unchanged to ``pipe``.
        pipe: Callable (e.g. a transformers token-classification pipeline)
            returning a list of dicts with at least ``"score"`` and ``"word"``.
        min_score: Entities scoring below this value are dropped.
        min_length: Minimum length of the whitespace-stripped ``"word"``.
        remove_punctuation: When true, drop entities whose ``"word"`` contains
            no ASCII letter or digit (punctuation-only predictions).

    Returns:
        The surviving entity dicts, in the order ``pipe`` returned them. The
        dicts are not copied.
    """
    # At least one ASCII letter or digit.
    has_content = re.compile(r"[A-Za-z0-9]")

    filtered_entities = []
    for entity in pipe(text):
        word = entity["word"]
        if entity["score"] < min_score:
            continue
        if len(word.strip()) < min_length:
            continue
        if remove_punctuation and not has_content.search(word):
            continue
        filtered_entities.append(entity)
    return filtered_entities
def advanced_ner_filter(text, *, pipe, min_score=0.60, strip_edges=True, exclude_patterns=None):
    """Filter ``pipe(text)`` predictions with edge stripping and pattern exclusion.

    Args:
        text: Input text passed to ``pipe``; also used to check that entity
            offsets point at the entity's ``"word"``.
        pipe: Callable returning a list of entity dicts with at least
            ``"score"`` and ``"word"`` (optionally ``"start"``/``"end"``).
        min_score: Entities scoring below this value are dropped.
        strip_edges: When true, strip leading/trailing punctuation from each
            ``"word"``. Entities that become empty are dropped. When the
            entity's ``start``/``end`` slice exactly ``word`` out of ``text``,
            the offsets are narrowed so they still match the stripped word.
        exclude_patterns: Optional iterable of regex patterns. An entity is
            dropped when any pattern matches at the start of its (possibly
            stripped) word (``re.match`` semantics).

    Returns:
        Entities that remain and contain at least one ASCII letter or digit.
        Stripped entities are shallow copies; input dicts are never mutated.
    """
    edge_chars = ".,!?;:()[]{}\"'-_"
    filtered = []
    for entity in pipe(text):
        if entity["score"] < min_score:
            continue
        word = entity["word"]

        # Strip punctuation from edges
        if strip_edges:
            stripped = word.strip(edge_chars)
            if not stripped:
                continue
            entity = entity.copy()
            entity["word"] = stripped
            start, end = entity.get("start"), entity.get("end")
            # Keep offsets consistent with the new word, but only when they
            # provably index ``word``. A pipeline's "word" can differ from the
            # source slice, and shifting then would corrupt the offsets.
            if start is not None and end is not None and text[start:end] == word:
                entity["start"] = start + (len(word) - len(word.lstrip(edge_chars)))
                entity["end"] = end - (len(word) - len(word.rstrip(edge_chars)))

        # Apply exclusion patterns
        if exclude_patterns and any(re.match(pattern, entity["word"]) for pattern in exclude_patterns):
            continue

        # Only keep entities with actual content
        if re.search(r"[A-Za-z0-9]", entity["word"]):
            filtered.append(entity)
    return filtered
# Gap text allowed between two mergeable entities: whitespace, "-", ",", "/"
# and the whole word "and" (any case). The previous character class
# ``[\s\-,/and]`` also accepted stray letters such as " a " or "nad".
_CONNECTOR_GAP = re.compile(r"(?:[\s,/-]|\band\b)*", re.IGNORECASE)


def merge_adjacent_entities(entities, original_text, max_gap=10):
    """Merge consecutive same-type entities separated only by connectors.

    Two consecutive entities are merged when they share ``entity_group``,
    the gap between them is at most ``max_gap`` characters, and the gap text
    consists only of whitespace, ``-``, ``,``, ``/`` and the word ``and``.
    Useful for handling cases like "BRCA1 and BRCA2" or "HER2-positive".

    Args:
        entities: Entity dicts with ``entity_group``, ``start``, ``end``,
            ``word`` and ``score``. Assumed to be ordered by ``start``
            (TODO confirm for all callers).
        original_text: The text the offsets index into.
        max_gap: Maximum number of characters allowed between two entities.

    Returns:
        Lists with fewer than two entities are returned as-is. Otherwise a
        new list of copied dicts: a merged entity spans from its first start
        to the furthest end, its ``word`` is that slice of ``original_text``,
        and its ``score`` is the mean of the merged entities' scores.
    """
    if len(entities) < 2:
        return entities

    merged = []
    current = entities[0].copy()
    merged_count = 1  # entities folded into ``current``, for a true mean score

    for next_entity in entities[1:]:
        gap_start, gap_end = current["end"], next_entity["start"]
        # For overlapping entities gap_end < gap_start, the gap slice is ""
        # and they are merged.
        if (current["entity_group"] == next_entity["entity_group"]
                and gap_end - gap_start <= max_gap
                and _CONNECTOR_GAP.fullmatch(original_text[gap_start:gap_end])):
            # max() so a span nested inside ``current`` cannot shrink it.
            current["end"] = max(current["end"], next_entity["end"])
            current["word"] = original_text[current["start"]:current["end"]]
            merged_count += 1
            current["score"] += (next_entity["score"] - current["score"]) / merged_count
            continue

        # No merge: emit current and move on to the next entity.
        merged.append(current)
        current = next_entity.copy()
        merged_count = 1

    # Don't forget the last entity
    merged.append(current)
    return merged
class MedicalNERApp:
    """Load OpenMed NER pipelines and turn their predictions into HTML + Markdown.

    Attributes:
        pipelines: Maps each display name in ``MODELS`` to a loaded
            transformers token-classification pipeline, or to ``None`` when
            loading that model failed.
        nlp: Blank English spaCy pipeline. It is used only to tokenize text,
            so that character-offset entities can be rendered with displaCy.
    """

    def __init__(self):
        self.pipelines = {}
        self.nlp = spacy.blank("en")  # SpaCy model for visualization
        self.load_models()

    @staticmethod
    def _select_device():
        """Return the transformers ``device`` index: 0 (first CUDA GPU) if available, else -1 (CPU)."""
        try:
            import torch
        except ImportError:
            return -1
        return 0 if torch.cuda.is_available() else -1

    def load_models(self):
        """Load every model in ``MODELS`` into ``self.pipelines``.

        Pipelines are created with ``aggregation_strategy=None``, so they
        return raw per-token BIO predictions that ``smart_group_entities``
        merges. A model that fails to load is reported and stored as
        ``None``, so the app still starts with the remaining models.
        """
        print("🏥 Loading Medical NER Models...")
        # Choose GPU vs CPU from actual CUDA availability, not from how this
        # module was launched. A hard-coded device=0 fails on CPU-only hosts.
        device = self._select_device()
        for model_name, config in MODELS.items():
            print(f"Loading {model_name}...")
            try:
                ner_pipeline = pipeline(
                    "token-classification",
                    model=config["model_id"],
                    aggregation_strategy=None,  # raw tokens; grouped by smart_group_entities
                    device=device,
                )
                self.pipelines[model_name] = ner_pipeline
                print(f"✅ {model_name} loaded successfully with custom entity grouping")
            except Exception as e:
                print(f"❌ Error loading {model_name}: {str(e)}")
                self.pipelines[model_name] = None
        print("🎉 All models loaded and cached!")

    def smart_group_entities(self, tokens, text):
        """Merge raw per-token BIO predictions into whole-entity spans.

        Grouping rules:
        - ``O`` closes the open entity.
        - ``B-X`` always opens a new entity.
        - ``I-X`` (or an unprefixed ``X``) extends the open entity when the
          type matches, and otherwise opens a new one.
        As a result, an ``I-`` tag with no preceding ``B-`` starts an entity
        instead of being silently dropped.

        Args:
            tokens: Token dicts as returned by a token-classification pipeline
                with ``aggregation_strategy=None`` (keys ``entity``, ``score``,
                ``start``, ``end``), in text order.
            text: The text the offsets index into. Each entity's ``word`` is
                sliced from it, so sub-word markers never reach the output.

        Returns:
            Dicts with ``entity_group``, ``score`` (mean of the merged token
            scores), ``word``, ``start`` and ``end``.
        """
        if not tokens:
            return []

        entities = []
        current_entity = None
        token_count = 0  # tokens merged into current_entity, for the running mean

        for token in tokens:
            label = token['entity']

            # Skip O (Outside) tags; they terminate any open entity.
            if label == 'O':
                if current_entity:
                    entities.append(current_entity)
                current_entity = None
                continue

            score = token['score']
            start = token['start']
            end = token['end']
            # Strip only a *leading* BIO prefix. str.replace would also mangle
            # labels that merely contain "B-"/"I-" (e.g. "ANTI-INFLAMMATORY").
            clean_label = label[2:] if label[:2] in ('B-', 'I-') else label

            extends_current = (
                current_entity is not None
                and not label.startswith('B-')
                and current_entity['entity_group'] == clean_label
            )
            if extends_current:
                token_count += 1
                current_entity['end'] = end
                current_entity['word'] = text[current_entity['start']:end]
                # Incremental arithmetic mean over all tokens of the entity.
                current_entity['score'] += (score - current_entity['score']) / token_count
            else:
                if current_entity:
                    entities.append(current_entity)
                current_entity = {
                    'entity_group': clean_label,
                    'score': score,
                    'word': text[start:end],  # Use actual text from the source
                    'start': start,
                    'end': end,
                }
                token_count = 1

        # Don't forget the last entity
        if current_entity:
            entities.append(current_entity)
        return entities

    def create_spacy_visualization(self, text: str, entities: List[Dict], model_name: str) -> str:
        """Render ``entities`` over ``text`` as displaCy entity-style HTML.

        Each entity's character offsets are mapped onto tokens of the blank
        spaCy pipeline. Offsets that split a token are widened with
        ``alignment_mode="expand"``, and overlapping spans are resolved with
        ``spacy.util.filter_spans``. Detailed debug output goes to stdout.

        Args:
            text: The analysed input text.
            entities: Dicts with ``start``, ``end``, ``entity_group``, ``word``
                and ``score``; offsets index into ``text``.
            model_name: Display name, used only in the debug output.

        Returns:
            displaCy HTML. When some entities could not be aligned to tokens,
            a notice is appended.
        """
        print(f"\n🔍 VISUALIZATION DEBUG for {model_name}")
        print(f"Input text length: {len(text)} chars")
        print(f"Total entities to visualize: {len(entities)}")

        # Show all entities found, grouped by type for the counts below.
        print("\n📋 ENTITIES TO VISUALIZE:")
        entity_by_type = {}
        for i, ent in enumerate(entities):
            entity_type = ent['entity_group']
            entity_by_type.setdefault(entity_type, []).append(ent)
            print(f" {i+1:2d}. [{ent['start']:3d}:{ent['end']:3d}] '{ent['word']:25}' -> {entity_type:20} (score: {ent['score']:.3f})")

        print("\n📊 ENTITY COUNTS BY TYPE:")
        for entity_type, ents in entity_by_type.items():
            print(f" {entity_type}: {len(ents)} instances")

        doc = self.nlp(text)
        spacy_ents = []
        failed_entities = []

        print("\n🔧 CREATING SPACY SPANS:")
        for i, entity in enumerate(entities):
            try:
                start = entity['start']
                end = entity['end']
                label = entity['entity_group']
                entity_text = entity['word']
                print(f" {i+1:2d}. Trying span [{start}:{end}] '{entity_text}' -> {label}")

                # Exact token alignment first ("strict" mode).
                span = doc.char_span(start, end, label=label)
                if span is not None:
                    spacy_ents.append(span)
                    print(f" ✅ SUCCESS: '{span.text}' -> {label}")
                    continue

                # Offsets split a token: widen to whole tokens instead of dropping.
                span = doc.char_span(start, end, label=label, alignment_mode="expand")
                if span is not None:
                    spacy_ents.append(span)
                    print(f" ✅ SUCCESS (expand): '{span.text}' -> {label}")
                else:
                    failed_entities.append(entity)
                    print(f" ❌ FAILED: Could not create span for '{entity_text}' -> {label}")
            except Exception as e:
                failed_entities.append(entity)
                print(f" 💥 EXCEPTION: {str(e)}")

        print("\n📈 SPAN CREATION RESULTS:")
        print(f" ✅ Successful spans: {len(spacy_ents)}")
        print(f" ❌ Failed spans: {len(failed_entities)}")

        # A token may belong to only one entity, so drop overlapping spans.
        print("\n🔄 FILTERING OVERLAPPING SPANS...")
        print(f" Before filtering: {len(spacy_ents)} spans")
        spacy_ents = spacy.util.filter_spans(spacy_ents)
        print(f" After filtering: {len(spacy_ents)} spans")
        doc.ents = spacy_ents

        print("\n🎨 FINAL VISUALIZATION ENTITIES:")
        for ent in doc.ents:
            print(f" '{ent.text}' ({ent.label_}) [{ent.start_char}:{ent.end_char}]")

        # Define color palette
        color_palette = {
            "DISEASE": "#FF5733",
            "CHEM": "#33FF57",
            "GENE/PROTEIN": "#3357FF",
            "Cancer": "#FF33F6",
            "Cell": "#33FFF6",
            "Organ": "#F6FF33",
            "Tissue": "#FF8333",
            "Simple_chemical": "#8333FF",
            "Gene_or_gene_product": "#33FF83",
            "Organism": "#FF6B33",
        }
        unique_labels = sorted({ent.label_ for ent in doc.ents})
        colors = {}
        for label in unique_labels:
            if label in color_palette:
                colors[label] = color_palette[label]
            else:
                # Labels without a fixed colour get a random light colour (each
                # RGB channel 100-255), so the colour can differ between runs.
                red, green, blue = (random.randint(100, 255) for _ in range(3))
                colors[label] = f"#{red:02x}{green:02x}{blue:02x}"

        options = {
            "ents": unique_labels,
            "colors": colors,
            "style": "max-width: 100%; line-height: 2.5; direction: ltr;",
        }
        print("\n🎨 VISUALIZATION CONFIG:")
        print(f" Entity types for display: {unique_labels}")
        print(f" Color mapping: {colors}")

        # Add debug info to the HTML output if there are issues
        debug_info = ""
        if failed_entities:
            debug_info = f"""
            <div style="margin-top: 15px; padding: 10px; background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; font-size: 12px;">
                <strong>⚠️ Visualization Info:</strong><br>
                {len(failed_entities)} entities could not be visualized due to text alignment issues.<br>
                All entities are still counted in the summary below.
            </div>
            """

        displacy_html = displacy.render(doc, style="ent", options=options, page=False)
        return displacy_html + debug_info

    def predict_entities(self, text: str, model_name: str, confidence_threshold: float = 0.60) -> Tuple[str, str]:
        """Run ``model_name`` on ``text`` and return ``(html, markdown_summary)``.

        Raw token predictions are grouped with ``smart_group_entities``. Then
        entities scoring below ``confidence_threshold``, or containing no
        ASCII letter or digit, are dropped. Exceptions raised during
        prediction are caught and returned as an error message pair.
        """
        if not text or not text.strip():
            return "<p>Please enter medical text to analyze.</p>", "No text provided"

        pipeline_instance = self.pipelines.get(model_name)
        if pipeline_instance is None:
            # model_name may come from an API client rather than the dropdown,
            # so escape it before echoing it back into HTML.
            return f"<p>❌ Model {html.escape(str(model_name))} is not available.</p>", "Model not available"

        try:
            print(f"\nDEBUG: Processing text with {model_name}")
            print(f"Text: {text}")
            print(f"Confidence threshold: {confidence_threshold}")

            # Get raw token predictions from the pipeline
            raw_tokens = pipeline_instance(text)
            print(f"Got {len(raw_tokens)} raw tokens from pipeline")
            if not raw_tokens:
                return "<p>No entities detected.</p>", "No entities found"

            # Use our smart grouping to merge sub-tokens into complete entities
            grouped_entities = self.smart_group_entities(raw_tokens, text)
            print(f"Smart grouping created {len(grouped_entities)} entities")

            # Confidence filter plus content check. A word containing a letter
            # or digit is necessarily non-empty after stripping.
            filtered_entities = [
                entity for entity in grouped_entities
                if entity["score"] >= confidence_threshold
                and re.search(r"[A-Za-z0-9]", entity["word"])
            ]
            print(f"✅ After confidence filtering: {len(filtered_entities)} high-quality entities")
            if not filtered_entities:
                return f"<p>No entities found with confidence ≥ {confidence_threshold:.0%}. Try lowering the threshold.</p>", "No entities found"

            # Create visualization and summary
            html_output = self.create_spacy_visualization(text, filtered_entities, model_name)
            wrapped_html = self.wrap_displacy_output(html_output, model_name, len(filtered_entities), confidence_threshold)
            summary = self.create_summary(filtered_entities, model_name, confidence_threshold)
            return wrapped_html, summary
        except Exception as e:
            import traceback
            print(f"ERROR in predict_entities: {str(e)}")
            traceback.print_exc()
            error_msg = f"Error during prediction: {str(e)}"
            return f"<p>❌ {error_msg}</p>", error_msg

    def wrap_displacy_output(self, displacy_html: str, model_name: str, entity_count: int, confidence_threshold: float) -> str:
        """Wrap displaCy HTML in a styled card whose header shows the model and entity count."""
        # The pipelines run with aggregation_strategy=None and are grouped by
        # smart_group_entities, so the banner describes that. It previously
        # claimed aggregation_strategy="simple", which is wrong.
        return f"""
        <div style="font-family: 'Segoe UI', Arial, sans-serif;
                    border-radius: 10px;
                    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
                    overflow: hidden;">
            <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; padding: 15px; text-align: center;">
                <h3 style="margin: 0; font-size: 18px;">{model_name}</h3>
                <p style="margin: 5px 0 0 0; opacity: 0.9; font-size: 14px;">
                    Found {entity_count} high-confidence medical entities (≥{confidence_threshold:.0%})
                </p>
                <div style="margin-top: 8px; font-size: 12px; opacity: 0.8;">
                    ✅ Smart BIO token grouping + confidence threshold filtering
                </div>
            </div>
            <div style="padding: 20px; margin: 0; line-height: 2.5;">
                {displacy_html}
            </div>
        </div>
        """

    def create_summary(self, entities: List[Dict], model_name: str, confidence_threshold: float) -> str:
        """Build a Markdown summary: per-type counts, mean confidence, examples and model info.

        ``model_name`` is looked up in ``MODELS``, so it must be a key there.
        """
        if not entities:
            return "No entities detected."

        entity_counts = {}
        for entity in entities:
            entity_counts.setdefault(entity["entity_group"], []).append(entity)

        summary_parts = [f"📊 **{model_name} Analysis Results**\n"]
        summary_parts.append(f"**Total high-confidence entities**: {len(entities)} (threshold ≥{confidence_threshold:.0%})\n")
        for label, ents in sorted(entity_counts.items()):
            avg_confidence = sum(e["score"] for e in ents) / len(ents)
            unique_texts = sorted({e["word"] for e in ents})
            summary_parts.append(
                f"• **{label}**: {len(ents)} instances "
                f"(avg confidence: {avg_confidence:.2f})\n"
                f" Examples: {', '.join(unique_texts[:3])}"
                f"{'...' if len(unique_texts) > 3 else ''}\n"
            )

        # Add filtering information
        summary_parts.append("\n🎯 **Accuracy Improvements Applied**\n")
        summary_parts.append("✅ Smart BIO token grouping - Properly merges sub-tokens into complete entities\n")
        summary_parts.append(f"✅ Confidence threshold filtering - Only entities ≥ {confidence_threshold:.0%} confidence\n")
        summary_parts.append("✅ Content validation - Excludes empty or punctuation-only predictions\n")
        summary_parts.append("✅ Precise span alignment - Improved text-to-visual mapping\n")

        # Add model information
        summary_parts.append("\n🔬 **Model Information**\n")
        summary_parts.append(f"Model: `{MODELS[model_name]['model_id']}`\n")
        summary_parts.append(f"Description: {MODELS[model_name]['description']}\n")
        return "\n".join(summary_parts)
# Initialize the app. All pipelines are loaded here, at import time, so the
# Gradio handlers below can use ``ner_app`` directly.
print("🚀 Initializing Medical NER Application...")
ner_app = MedicalNERApp()

# Warmup: run one prediction per successfully loaded model (presumably to
# absorb first-call latency before real requests arrive).
print("🔥 Warming up models...")
warmup_text = "The patient has diabetes and takes metformin."
for model_name in MODELS:
    if ner_app.pipelines.get(model_name) is None:
        continue
    print(f"Warming up {model_name}...")
    try:
        _, warmup_summary = ner_app.predict_entities(warmup_text, model_name, 0.60)
    except Exception as e:
        print(f"⚠️ Warmup failed for {model_name}: {str(e)}")
        continue
    # predict_entities catches its own exceptions and reports them through a
    # summary starting with "Error during prediction", so an ``except`` alone
    # would report every failure as a success.
    if warmup_summary.startswith("Error during prediction"):
        print(f"⚠️ Warmup failed for {model_name}: {warmup_summary}")
    else:
        print(f"✅ {model_name} warmed up successfully")
print("🎉 Model warmup complete!")
def predict_wrapper(text: str, model_name: str, confidence_threshold: float):
    """Gradio click handler: delegate to the shared app instance.

    Returns:
        ``(html_output, summary)`` as produced by ``ner_app.predict_entities``.
    """
    return ner_app.predict_entities(text, model_name, confidence_threshold)
def load_example(model_name: str, example_idx: int):
    """Return example text number ``example_idx`` for ``model_name``.

    Returns "" when the model has no examples or the index is out of range.
    """
    examples = EXAMPLES.get(model_name, [])
    if 0 <= example_idx < len(examples):
        return examples[example_idx]
    return ""
_DEFAULT_MODEL = "Oncology Detection"  # model preselected in the dropdown


def update_model_info(model_name):
    """Return the HTML info card for ``model_name``, or "" if it is not in MODELS."""
    if model_name not in MODELS:
        return ""
    config = MODELS[model_name]
    return f"""
    <div class="model-info">
        <strong>{model_name}</strong><br>
        {config['description']}<br>
        <small>Model: {config['model_id']}</small>
    </div>
    """


# Create Gradio interface
with gr.Blocks(
    title="🏥 Medical NER Expert",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .main-header {
        text-align: center;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    }
    .model-info {
        padding: 1rem;
        border-radius: 10px;
        border-left: 4px solid #667eea;
        margin: 1rem 0;
    }
    .accuracy-badge {
        background: #28a745;
        color: white;
        padding: 4px 8px;
        border-radius: 12px;
        font-size: 12px;
        font-weight: bold;
    }
    """,
) as demo:
    # Header. The pipelines use aggregation_strategy=None plus custom BIO
    # grouping, so the text describes that rather than "simple" aggregation.
    gr.HTML(
        """
        <div class="main-header">
            <h1>🏥 Medical NER Expert</h1>
            <p>Advanced Named Entity Recognition for Medical Professionals</p>
            <div style="margin-top: 10px;">
                <span class="accuracy-badge">✅ HIGH ACCURACY MODE</span>
            </div>
            <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
                Powered by OpenMed models + proven filtering techniques (smart BIO token grouping + confidence thresholds)
            </p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Model selection
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=_DEFAULT_MODEL,
                label="🔬 Select Medical NER Model",
                info="Choose the specialized model for your analysis",
            )

            # Model info display: rendered by the same function used on
            # dropdown change, so the initial card also shows the model id.
            model_info = gr.HTML(value=update_model_info(_DEFAULT_MODEL))

            # Confidence threshold slider
            confidence_slider = gr.Slider(
                minimum=0.30,
                maximum=0.95,
                value=0.60,
                step=0.05,
                label="🎯 Confidence Threshold",
                info="Higher values = fewer but more confident predictions",
            )

            # Text input
            text_input = gr.Textbox(
                lines=8,
                placeholder="Enter medical text here for entity recognition...",
                label="📝 Medical Text Input",
                value=EXAMPLES[_DEFAULT_MODEL][0],
            )

            # Example buttons (one per example text)
            with gr.Row():
                example_buttons = []
                for i in range(3):
                    btn = gr.Button(f"Example {i+1}", size="sm", variant="secondary")
                    example_buttons.append(btn)

            # Analyze button
            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")

        with gr.Column(scale=3):
            # Results
            results_html = gr.HTML(
                label="🎯 Entity Recognition Results",
                value="<p>Select a model and enter text to see entity recognition results.</p>",
            )
            # Summary
            summary_output = gr.Markdown(
                value="Analysis summary will appear here...",
                label="📊 Analysis Summary",
            )

    # Update model info when model changes
    model_dropdown.change(
        update_model_info, inputs=[model_dropdown], outputs=[model_info]
    )

    # Example button handlers; idx=i binds the index now (avoids late binding).
    for i, btn in enumerate(example_buttons):
        btn.click(
            lambda model_name, idx=i: load_example(model_name, idx),
            inputs=[model_dropdown],
            outputs=[text_input],
        )

    # Main analysis function
    analyze_btn.click(
        predict_wrapper,
        inputs=[text_input, model_dropdown, confidence_slider],
        outputs=[results_html, summary_output],
    )

    # Auto-update when model changes (load first example)
    model_dropdown.change(
        lambda model_name: load_example(model_name, 0),
        inputs=[model_dropdown],
        outputs=[text_input],
    )
if __name__ == "__main__":
    # Serve on all network interfaces (0.0.0.0) at port 7860.
    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )