import math

import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker import Result
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder

# Button variants used to signal whether an action is currently available.
DISABLED_BUTTON_VARIANT = "huggingface"
ENABLED_BUTTON_VARIANT = "primary"

# Static text shown by the demo. These are runtime strings; their content and
# whitespace are kept exactly as found.
HEADLINE = """

TransformerRanker

A very simple library that helps you find the best-suited language model for your NLP task. All you need to do is to select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub. TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!

GitHub Badge Package Badge Tutorials Badge License: MIT

Developed at Humboldt University of Berlin.

"""

FOOTER = """ **Note:** This demonstration currently runs on a CPU and is suited for smaller models only. **Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas). For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions). """

CSS = """ .gradio-container{max-width: 800px !important} a {color: #ff9d00;} @media (prefers-color-scheme: dark) { a {color: #be185d;} } """

# Shared Hub client used for dataset lookups.
hf_api = HfApi()


def check_dataset_exists(dataset_name):
    """Update loading button if dataset can be found.

    Looks the name up via ``hf_api.dataset_info``. On success the returned
    update enables the button with the "enabled" variant. If the lookup
    raises ``HTTPError`` or ``HFValidationError`` the returned update resets
    the label to "Load dataset" and disables the button.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(
            value="Load dataset",
            interactive=False,
            variant=DISABLED_BUTTON_VARIANT,
        )
    return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)


def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the button only when a dataset and all settings are chosen.

    ``"-"`` is the placeholder for "nothing selected"; the button stays
    disabled while ``dataset`` is falsy or any of the three settings still
    holds that placeholder.
    """
    is_ready = (
        dataset
        and text_column != "-"
        and label_column != "-"
        and task_category != "-"
    )
    if is_ready:
        return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
    return gr.update(interactive=False, variant=DISABLED_BUTTON_VARIANT)


def get_dataset_info(dataset):
    """Show information for dataset settings.

    Concatenates every split in ``dataset`` (it is read through
    ``dataset.values()``), tries to detect the text column, the label column
    and the task category, and warns in the UI for each one that cannot be
    determined (falling back to the ``"-"`` placeholder).

    Returns a 5-tuple: updates for the task-category selector, the text-column
    selector, an additional column selector that defaults to ``"-"``
    (presumably the optional text-pair column -- confirm against the UI
    wiring), the label-column selector, and the total number of samples.
    """
    joined_dataset = concatenate_datasets(list(dataset.values()))
    datacleaner = DatasetCleaner()

    try:
        text_column = datacleaner._find_column(joined_dataset, "text column")
    except ValueError:
        gr.Warning("Text column can not be found. Select it in the dataset settings.")
        text_column = "-"

    try:
        label_column = datacleaner._find_column(joined_dataset, "label column")
    except ValueError:
        gr.Warning("Label column can not be found. Select it in the dataset settings.")
        label_column = "-"

    task_category = "-"
    # The task category is derived from the label column, so skip detection
    # when no label column was found.
    if label_column != "-":
        try:
            # Find or set the task_category
            task_category = datacleaner._find_task_category(joined_dataset, label_column)
        except ValueError:
            gr.Warning(
                "Task category could not be determined. "
                "The dataset must support classification or regression tasks.",
            )

    num_samples = len(joined_dataset)
    column_names = joined_dataset.column_names

    return (
        gr.update(
            value=task_category,
            choices=[str(t) for t in TaskCategory],
            interactive=True,
        ),
        gr.update(value=text_column, choices=column_names, interactive=True),
        gr.update(value="-", choices=["-", *column_names], interactive=True),
        gr.update(value=label_column, choices=column_names, interactive=True),
        num_samples,
    )


def compute_ratio(num_samples_to_use, num_samples):
    """Return ``num_samples_to_use / num_samples``, or ``0.0`` if there are no samples."""
    if num_samples > 0:
        return num_samples_to_use / num_samples
    return 0.0


def ensure_one_lm_selected(checkbox_values, previous_values):
    """Return the new selection, or the previous one if the new one is empty."""
    if not any(checkbox_values):
        return previous_values
    return checkbox_values


# Apply monkey patch to enable callbacks
_old_embed = Embedder.embed


def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    """Wrap ``Embedder.embed`` to announce the batch count to the tracker.

    NOTE(review): the default ``batch_size`` of 32 is assumed to mirror the
    default of the wrapped method -- confirm against ``Embedder.embed``.
    """
    if embedder.tracker is not None:
        embedder.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))
    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)


Embedder.embed = _new_embed

_old_embed_batch = Embedder.embed_batch


def _new_embed_batch(embedder, *args, **kw):
    """Wrap ``Embedder.embed_batch`` to report each completed batch."""
    r = _old_embed_batch(embedder, *args, **kw)
    if embedder.tracker is not None:
        embedder.tracker.update_batch_complete()
    return r


Embedder.embed_batch = _new_embed_batch

_old_init = Embedder.__init__


def _new_init(embedder, *args, tracker=None, **kw):
    """Wrap ``Embedder.__init__`` to accept an optional ``tracker`` keyword."""
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker


Embedder.__init__ = _new_init


class EmbeddingProgressTracker:
    """Context manager that maps per-model batch progress onto one progress bar.

    Each model owns an equal ``1 / total`` slice of the bar; finished batches
    fill the slice of the model that is currently running.
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        self.progress_bar = progress
        # Start from a clean state so the tracker is usable even if it is
        # driven without entering the ``with`` block first.
        self._reset_counters()

    def _reset_counters(self):
        """Reset the per-run counters to their initial values."""
        # -1 because update_num_batches() advances to the next model before
        # the bar is redrawn.
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    @property
    def total(self):
        """Number of models that will be processed."""
        return len(self.model_names)

    def __enter__(self):
        # NOTE(review): this replaces the progress object handed to __init__
        # with a fresh gr.Progress -- kept as-is, confirm it is intentional.
        self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset_counters()
        return self

    def __exit__(self, typ, value, tb):
        if typ is None:
            self.progress_bar(1.0, desc="Done")
        else:
            self.progress_bar(1.0, desc="Error")

        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Advance to the next model and record how many batches it has."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch for the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Redraw the bar from the current model index and batch counters."""
        i = self.current_model
        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"

        progress = i / self.total
        # batches_total is None until the first embed call and 0 when an empty
        # input is embedded; in both cases there is no batch fraction to add
        # (and dividing by 0 would raise ZeroDivisionError).
        if self.batches_total:
            progress += (self.batches_complete / self.batches_total) / self.total

        self.progress_bar(progress=progress, desc=description)