import os
import subprocess

import spaces
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from threading import Thread
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from typing import List
from PIL import Image
from transformers.image_transforms import resize, to_channel_dimension_format

# Install flash-attn without CUDA build isolation
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Set the device to GPU if available, otherwise use CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(DEVICE)

# Determine image sequence length (number of <image> placeholder tokens per image)
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids


## Utils

def convert_to_rgb(image):
    # Flatten any transparency onto a white background before converting to RGB
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def custom_transform(x):
    # Resize, rescale and normalize the image, then return a channel-first tensor
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x

## End of Utils


# Install Playwright browsers (needed to render the generated HTML)
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")


install_playwright()

IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}" for ex_image in os.listdir("example_images")
]


def add_file_gallery(selected_state: gr.SelectData, gallery_list: List[str]):
    # Load the selected gallery template as a PIL image
    return Image.open(gallery_list.root[selected_state.index].image.path)


def render_webpage(html_css_code):
    # Render the generated HTML/CSS in headless Chromium and return a full-page screenshot
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                " Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")

        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
        _ = page.screenshot(path=output_path_screenshot, full_page=True)

        context.close()
        browser.close()

    return Image.open(output_path_screenshot)


@spaces.GPU(duration=180)
def model_inference(image):
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")
It should be a PIL image.") inputs = PROCESSOR.tokenizer( f"{BOS_TOKEN}{'' * image_seq_len}", return_tensors="pt", add_special_tokens=False, ) inputs["pixel_values"] = PROCESSOR.image_processor( [image], transform=custom_transform ) inputs = {k: v.to(DEVICE) for k, v in inputs.items()} streamer = TextIteratorStreamer( PROCESSOR.tokenizer, skip_prompt=True, ) generation_kwargs = dict( inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096, streamer=streamer, ) thread = Thread( target=MODEL.generate, kwargs=generation_kwargs, ) thread.start() generated_text = "" for new_text in streamer: if "" in new_text: new_text = new_text.replace("", "") rendered_image = render_webpage(generated_text) else: rendered_image = None generated_text += new_text yield generated_text, rendered_image generated_html = gr.Code(label="Extracted HTML", elem_id="generated_html") rendered_html = gr.Image(label="Rendered HTML", show_download_button=False, show_share_button=False) css = """ .gradio-container{max-width: 1000px!important} h1{display: flex;align-items: center;justify-content: center;gap: .25em} *{transition: width 0.5s ease, flex-grow 0.5s ease} """ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo: gr.Markdown( "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content." ) with gr.Row(equal_height=True): with gr.Column(scale=4, min_width=250) as upload_area: imagebox = gr.Image( type="pil", label="Screenshot to extract", visible=True, sources=["upload", "clipboard"], ) with gr.Group(): with gr.Row(): submit_btn = gr.Button(value="▶️ Submit", visible=True, min_width=120) clear_btn = gr.ClearButton( [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120 ) regenerate_btn = gr.Button(value="🔄 Regenerate", visible=True, min_width=120) with gr.Column(scale=4): rendered_html.render() with gr.Row(): generated_html.render() with gr.Row(): template_gallery = gr.Gallery( value=IMAGE_GALLERY_PATHS, label="Templates Gallery", allow_preview=False, columns=5, elem_id="gallery", show_share_button=False, height=400, ) gr.on( triggers=[imagebox.upload, submit_btn.click, regenerate_btn.click], fn=model_inference, inputs=[imagebox], outputs=[generated_html, rendered_html], ) regenerate_btn.click( fn=model_inference, inputs=[imagebox], outputs=[generated_html, rendered_html], ) template_gallery.select( fn=add_file_gallery, inputs=[template_gallery], outputs=[imagebox], ).success( fn=model_inference, inputs=[imagebox], outputs=[generated_html, rendered_html], ) demo.load() demo.queue(max_size=40, api_open=False) demo.launch(max_threads=400)