hysts HF Staff commited on
Commit
7252e54
·
1 Parent(s): 6c2dda3
Files changed (6) hide show
  1. README.md +2 -2
  2. app.py +122 -3
  3. packages.txt +1 -0
  4. requirements.txt +2 -1
  5. scheduler.py +31 -0
  6. style.css +6 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Base Space
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
@@ -7,7 +7,7 @@ sdk: gradio
7
  sdk_version: 3.34.0
8
  python_version: 3.10.11
9
  app_file: app.py
10
- pinned: true
11
  license: mit
12
  duplicated_from: hysts-samples/base-space
13
  ---
 
1
  ---
2
+ title: Save user preferences
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: purple
 
7
  sdk_version: 3.34.0
8
  python_version: 3.10.11
9
  app_file: app.py
10
+ pinned: false
11
  license: mit
12
  duplicated_from: hysts-samples/base-space
13
  ---
app.py CHANGED
@@ -1,7 +1,126 @@
1
  #!/usr/bin/env python
2
 
 
 
 
 
 
 
 
 
 
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- with gr.Blocks() as demo:
6
- pass
7
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python
2
 
3
+ import datetime
4
+ import json
5
+ import os
6
+ import pathlib
7
+ import shutil
8
+ import tempfile
9
+ import uuid
10
+ from typing import Any
11
+
12
  import gradio as gr
13
+ from gradio_client import Client
14
+
15
+ from scheduler import ZipScheduler
16
+
17
+ HF_TOKEN = os.getenv('HF_TOKEN')
18
+ UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
19
+ UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '5'))
20
+ USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
21
+ LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
22
+ LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
23
+
24
+ scheduler = ZipScheduler(repo_id=UPLOAD_REPO_ID,
25
+ repo_type='dataset',
26
+ every=UPLOAD_FREQUENCY,
27
+ private=not USE_PUBLIC_REPO,
28
+ token=HF_TOKEN,
29
+ folder_path=LOCAL_SAVE_DIR)
30
+
31
+ client = Client('stabilityai/stable-diffusion')
32
+
33
+
34
+ def generate(prompt: str) -> tuple[str, list[str]]:
35
+ negative_prompt = ''
36
+ guidance_scale = 9
37
+ out_dir = client.predict(prompt,
38
+ negative_prompt,
39
+ guidance_scale,
40
+ fn_index=1)
41
+
42
+ config = {
43
+ 'prompt': prompt,
44
+ 'negative_prompt': negative_prompt,
45
+ 'guidance_scale': guidance_scale,
46
+ }
47
+ config_file = tempfile.NamedTemporaryFile(mode='w',
48
+ suffix='.json',
49
+ delete=False)
50
+ json.dump(config, config_file)
51
+
52
+ with open(pathlib.Path(out_dir) / 'captions.json') as f:
53
+ paths = list(json.load(f).keys())
54
+ return config_file.name, paths
55
+
56
+
57
+ def get_selected_index(evt: gr.SelectData) -> int:
58
+ return evt.index
59
+
60
+
61
+ def save_preference(config_path: str, gallery: list[dict[str, Any]],
62
+ selected_index: int) -> None:
63
+ save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
64
+ save_dir.mkdir(parents=True, exist_ok=True)
65
+
66
+ paths = [x['name'] for x in gallery]
67
+ with scheduler.lock:
68
+ for index, path in enumerate(paths):
69
+ ext = pathlib.Path(path).suffix
70
+ shutil.move(path, save_dir / f'{index:03d}{ext}')
71
+
72
+ with open(config_path) as f:
73
+ config = json.load(f)
74
+ json_path = save_dir / 'preferences.json'
75
+ with json_path.open('w') as f:
76
+ preferences = config | {
77
+ 'selected_index': selected_index,
78
+ 'timestamp': datetime.datetime.utcnow().isoformat(),
79
+ }
80
+ json.dump(preferences, f)
81
+
82
+
83
+ def clear() -> tuple[dict, dict, dict]:
84
+ return (
85
+ gr.update(value=None),
86
+ gr.update(value=None),
87
+ gr.update(interactive=False),
88
+ )
89
+
90
+
91
+ with gr.Blocks(css='style.css') as demo:
92
+ with gr.Group():
93
+ prompt = gr.Text(show_label=False, placeholder='Prompt')
94
+ config_path = gr.Text(visible=False)
95
+ gallery = gr.Gallery(show_label=False).style(columns=2,
96
+ rows=2,
97
+ height='600px',
98
+ object_fit='scale-down')
99
+ selected_index = gr.Number(visible=False, precision=0)
100
+ save_preference_button = gr.Button('Save preference', interactive=False)
101
+
102
+ prompt.submit(
103
+ fn=generate,
104
+ inputs=prompt,
105
+ outputs=[config_path, gallery],
106
+ ).success(
107
+ fn=lambda: gr.update(interactive=True),
108
+ outputs=save_preference_button,
109
+ queue=False,
110
+ )
111
 
112
+ gallery.select(
113
+ fn=get_selected_index,
114
+ outputs=selected_index,
115
+ queue=False,
116
+ )
117
+ save_preference_button.click(
118
+ fn=save_preference,
119
+ inputs=[config_path, gallery, selected_index],
120
+ queue=False,
121
+ ).then(
122
+ fn=clear,
123
+ outputs=[config_path, gallery, save_preference_button],
124
+ queue=False,
125
+ )
126
+ demo.queue(concurrency_count=5).launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tar
requirements.txt CHANGED
@@ -1 +1,2 @@
1
-
 
 
1
+ git+https://github.com/huggingface/huggingface_hub@extendable-commit-scheduler
2
+ gradio_client==0.2.6
scheduler.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import shlex
3
+ import shutil
4
+ import subprocess
5
+ import tempfile
6
+ import uuid
7
+
8
+ from huggingface_hub import CommitScheduler
9
+
10
+
11
+ class ZipScheduler(CommitScheduler):
12
+ def push_to_hub(self):
13
+ with self.lock:
14
+ paths = sorted(self.folder_path.glob('*'))
15
+ if not paths:
16
+ return
17
+ with tempfile.TemporaryDirectory() as tmpdir:
18
+ archive_path = pathlib.Path(tmpdir) / f'{uuid.uuid4()}.tar.gz'
19
+ subprocess.run(shlex.split(
20
+ f'tar czf {archive_path} {self.folder_path.name}'),
21
+ cwd=self.folder_path.parent)
22
+ self.api.upload_file(
23
+ repo_id=self.repo_id,
24
+ repo_type=self.repo_type,
25
+ revision=self.revision,
26
+ path_in_repo=archive_path.name,
27
+ path_or_fileobj=archive_path,
28
+ token=self.token,
29
+ )
30
+ shutil.rmtree(self.folder_path, ignore_errors=True)
31
+ self.folder_path.mkdir(parents=True, exist_ok=True)
style.css CHANGED
@@ -1,3 +1,9 @@
1
  h1 {
2
  text-align: center;
3
  }
 
 
 
 
 
 
 
1
  h1 {
2
  text-align: center;
3
  }
4
+
5
+ #component-0 {
6
+ max-width: 800px;
7
+ margin: auto;
8
+ padding-top: 1.5rem;
9
+ }