File size: 1,905 Bytes
efdce2e
 
116a187
026edab
116a187
 
 
 
026edab
116a187
aa5635b
 
 
efdce2e
026edab
efdce2e
 
026edab
 
 
aa5635b
116a187
 
 
 
 
 
 
 
 
 
 
 
 
c4af1bb
aa5635b
efdce2e
 
 
 
116a187
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import traceback
from functools import wraps
from os import abort, getenv
from sys import exception
from textwrap import dedent

import gradio as gr
from torch import cuda

from detikzify.webui import BANNER, build_ui, make_light

def is_official_demo():
    """Return True when running as the official "nllg" Hugging Face Space."""
    space_author = getenv("SPACE_AUTHOR_NAME")
    return space_author == "nllg"

def add_abort_hook(func, *errors):
    def wrapper(*args, **kwargs):
        if isinstance(exception(), errors):
            abort()
        return func(*args, **kwargs)
    return wrapper

# On the official Space without a GPU, the models cannot run: render a static
# notice asking users to duplicate the Space onto a paid GPU runtime instead
# of building the real UI.
if is_official_demo() and not cuda.is_available():
    center = ".gradio-container {text-align: center}"
    with gr.Blocks(css=center, theme=make_light(gr.themes.Soft()), title="DeTikZify") as demo:
        badge = "https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg"
        link = "https://huggingface.co/spaces/nllg/DeTikZify?duplicate=true"
        html = f'<a style="display:inline-block" href="{link}"> <img src="{badge}" alt="Duplicate this Space"> </a>'
        message = dedent("""\
        The resources required by our models surpass those provided by Hugging
        Face Spaces' free CPU tier. For full functionality, we suggest
        duplicating this space using a paid private GPU runtime.
        """)
        gr.HTML(f'{BANNER}\n<p>{message}</p>\n{html}')
else:
    # Select the big model only when GPU 0 reports more than 15,835,660,288
    # bytes (~14.7 GiB) of total memory — presumably a threshold chosen to
    # admit 16 GB-class cards; TODO confirm intent of the exact constant.
    use_big_models = cuda.is_available() and cuda.get_device_properties(0).total_memory > 15835660288
    model = f"detikzify-{'v2-8b' if use_big_models else 'ds-1.3b'}"
    # lock=... restricts model selection on the official demo; queue() enables
    # Gradio's request queue.
    demo = build_ui(lock=is_official_demo(), model=model, light=True).queue()
    # Hack to temporarily work around memory leak, see:
    #   * https://huggingface.co/spaces/nllg/DeTikZify/discussions/2
    #   * https://github.com/gradio-app/gradio/issues/8503
    traceback.print_exc = add_abort_hook(traceback.print_exc, MemoryError, cuda.OutOfMemoryError)

if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)