Spaces:
Runtime error
Runtime error
Update
Browse files- README.md +2 -2
- app.py +122 -3
- packages.txt +1 -0
- requirements.txt +2 -1
- scheduler.py +31 -0
- style.css +6 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: ⚡
|
4 |
colorFrom: red
|
5 |
colorTo: purple
|
@@ -7,7 +7,7 @@ sdk: gradio
|
|
7 |
sdk_version: 3.34.0
|
8 |
python_version: 3.10.11
|
9 |
app_file: app.py
|
10 |
-
pinned:
|
11 |
license: mit
|
12 |
duplicated_from: hysts-samples/base-space
|
13 |
---
|
|
|
1 |
---
|
2 |
+
title: Save user preferences
|
3 |
emoji: ⚡
|
4 |
colorFrom: red
|
5 |
colorTo: purple
|
|
|
7 |
sdk_version: 3.34.0
|
8 |
python_version: 3.10.11
|
9 |
app_file: app.py
|
10 |
+
pinned: false
|
11 |
license: mit
|
12 |
duplicated_from: hysts-samples/base-space
|
13 |
---
|
app.py
CHANGED
@@ -1,7 +1,126 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/usr/bin/env python
|
2 |
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import pathlib
|
7 |
+
import shutil
|
8 |
+
import tempfile
|
9 |
+
import uuid
|
10 |
+
from typing import Any
|
11 |
+
|
12 |
import gradio as gr
|
13 |
+
from gradio_client import Client
|
14 |
+
|
15 |
+
from scheduler import ZipScheduler
|
16 |
+
|
17 |
+
HF_TOKEN = os.getenv('HF_TOKEN')
|
18 |
+
UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
|
19 |
+
UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '5'))
|
20 |
+
USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
|
21 |
+
LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
|
22 |
+
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
|
23 |
+
|
24 |
+
scheduler = ZipScheduler(repo_id=UPLOAD_REPO_ID,
|
25 |
+
repo_type='dataset',
|
26 |
+
every=UPLOAD_FREQUENCY,
|
27 |
+
private=not USE_PUBLIC_REPO,
|
28 |
+
token=HF_TOKEN,
|
29 |
+
folder_path=LOCAL_SAVE_DIR)
|
30 |
+
|
31 |
+
client = Client('stabilityai/stable-diffusion')
|
32 |
+
|
33 |
+
|
34 |
+
def generate(prompt: str) -> tuple[str, list[str]]:
|
35 |
+
negative_prompt = ''
|
36 |
+
guidance_scale = 9
|
37 |
+
out_dir = client.predict(prompt,
|
38 |
+
negative_prompt,
|
39 |
+
guidance_scale,
|
40 |
+
fn_index=1)
|
41 |
+
|
42 |
+
config = {
|
43 |
+
'prompt': prompt,
|
44 |
+
'negative_prompt': negative_prompt,
|
45 |
+
'guidance_scale': guidance_scale,
|
46 |
+
}
|
47 |
+
config_file = tempfile.NamedTemporaryFile(mode='w',
|
48 |
+
suffix='.json',
|
49 |
+
delete=False)
|
50 |
+
json.dump(config, config_file)
|
51 |
+
|
52 |
+
with open(pathlib.Path(out_dir) / 'captions.json') as f:
|
53 |
+
paths = list(json.load(f).keys())
|
54 |
+
return config_file.name, paths
|
55 |
+
|
56 |
+
|
57 |
+
def get_selected_index(evt: gr.SelectData) -> int:
|
58 |
+
return evt.index
|
59 |
+
|
60 |
+
|
61 |
+
def save_preference(config_path: str, gallery: list[dict[str, Any]],
|
62 |
+
selected_index: int) -> None:
|
63 |
+
save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
|
64 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
65 |
+
|
66 |
+
paths = [x['name'] for x in gallery]
|
67 |
+
with scheduler.lock:
|
68 |
+
for index, path in enumerate(paths):
|
69 |
+
ext = pathlib.Path(path).suffix
|
70 |
+
shutil.move(path, save_dir / f'{index:03d}{ext}')
|
71 |
+
|
72 |
+
with open(config_path) as f:
|
73 |
+
config = json.load(f)
|
74 |
+
json_path = save_dir / 'preferences.json'
|
75 |
+
with json_path.open('w') as f:
|
76 |
+
preferences = config | {
|
77 |
+
'selected_index': selected_index,
|
78 |
+
'timestamp': datetime.datetime.utcnow().isoformat(),
|
79 |
+
}
|
80 |
+
json.dump(preferences, f)
|
81 |
+
|
82 |
+
|
83 |
+
def clear() -> tuple[dict, dict, dict]:
|
84 |
+
return (
|
85 |
+
gr.update(value=None),
|
86 |
+
gr.update(value=None),
|
87 |
+
gr.update(interactive=False),
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
with gr.Blocks(css='style.css') as demo:
|
92 |
+
with gr.Group():
|
93 |
+
prompt = gr.Text(show_label=False, placeholder='Prompt')
|
94 |
+
config_path = gr.Text(visible=False)
|
95 |
+
gallery = gr.Gallery(show_label=False).style(columns=2,
|
96 |
+
rows=2,
|
97 |
+
height='600px',
|
98 |
+
object_fit='scale-down')
|
99 |
+
selected_index = gr.Number(visible=False, precision=0)
|
100 |
+
save_preference_button = gr.Button('Save preference', interactive=False)
|
101 |
+
|
102 |
+
prompt.submit(
|
103 |
+
fn=generate,
|
104 |
+
inputs=prompt,
|
105 |
+
outputs=[config_path, gallery],
|
106 |
+
).success(
|
107 |
+
fn=lambda: gr.update(interactive=True),
|
108 |
+
outputs=save_preference_button,
|
109 |
+
queue=False,
|
110 |
+
)
|
111 |
|
112 |
+
gallery.select(
|
113 |
+
fn=get_selected_index,
|
114 |
+
outputs=selected_index,
|
115 |
+
queue=False,
|
116 |
+
)
|
117 |
+
save_preference_button.click(
|
118 |
+
fn=save_preference,
|
119 |
+
inputs=[config_path, gallery, selected_index],
|
120 |
+
queue=False,
|
121 |
+
).then(
|
122 |
+
fn=clear,
|
123 |
+
outputs=[config_path, gallery, save_preference_button],
|
124 |
+
queue=False,
|
125 |
+
)
|
126 |
+
demo.queue(concurrency_count=5).launch()
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
tar
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
|
|
|
|
1 |
+
git+https://github.com/huggingface/huggingface_hub@extendable-commit-scheduler
|
2 |
+
gradio_client==0.2.6
|
scheduler.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import shlex
|
3 |
+
import shutil
|
4 |
+
import subprocess
|
5 |
+
import tempfile
|
6 |
+
import uuid
|
7 |
+
|
8 |
+
from huggingface_hub import CommitScheduler
|
9 |
+
|
10 |
+
|
11 |
+
class ZipScheduler(CommitScheduler):
|
12 |
+
def push_to_hub(self):
|
13 |
+
with self.lock:
|
14 |
+
paths = sorted(self.folder_path.glob('*'))
|
15 |
+
if not paths:
|
16 |
+
return
|
17 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
18 |
+
archive_path = pathlib.Path(tmpdir) / f'{uuid.uuid4()}.tar.gz'
|
19 |
+
subprocess.run(shlex.split(
|
20 |
+
f'tar czf {archive_path} {self.folder_path.name}'),
|
21 |
+
cwd=self.folder_path.parent)
|
22 |
+
self.api.upload_file(
|
23 |
+
repo_id=self.repo_id,
|
24 |
+
repo_type=self.repo_type,
|
25 |
+
revision=self.revision,
|
26 |
+
path_in_repo=archive_path.name,
|
27 |
+
path_or_fileobj=archive_path,
|
28 |
+
token=self.token,
|
29 |
+
)
|
30 |
+
shutil.rmtree(self.folder_path, ignore_errors=True)
|
31 |
+
self.folder_path.mkdir(parents=True, exist_ok=True)
|
style.css
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
h1 {
|
2 |
text-align: center;
|
3 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
h1 {
|
2 |
text-align: center;
|
3 |
}
|
4 |
+
|
5 |
+
#component-0 {
|
6 |
+
max-width: 800px;
|
7 |
+
margin: auto;
|
8 |
+
padding-top: 1.5rem;
|
9 |
+
}
|