Spaces:
Runtime error
Runtime error
Commit
·
b58f9ec
1
Parent(s):
7d65544
initial commit
Browse files- README.md +1 -1
- app.py +57 -0
- config.py +7 -0
- requirements.txt +3 -0
- utils.py +168 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Role Play Crowdsource Signup
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Role Play Crowdsource Signup
|
3 |
+
emoji: 🔥
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import argilla as rg
|
5 |
+
from huggingface_hub import login
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
import config
|
9 |
+
import utils
|
10 |
+
|
11 |
+
|
12 |
+
def main():
    """Build the Gradio sign-up page and start serving it.

    The annotation records are loaded once through ``init()`` and bound to
    ``signup`` with ``functools.partial`` so that every button click reuses
    the same in-memory list instead of reloading the dataset.
    """
    records = init()

    with gr.Blocks() as demo:
        gr.Markdown("# Role-Play Crowdsource\n"
                    "TODO")
        username = gr.Textbox(label="Username", placeholder="alekseykorshuk")
        # Mask the password while it is typed instead of echoing it on screen.
        password = gr.Textbox(label="Password", placeholder="12345678", type="password")
        btn = gr.Button("Run")
        status = gr.Textbox(label="Status")

        btn.click(
            fn=partial(signup, records=records),
            inputs=[username, password], outputs=status)
    demo.launch()
|
27 |
+
|
28 |
+
|
29 |
+
def init():
    """Connect to Argilla and the Hugging Face Hub, then load the records.

    Returns:
        The value produced by ``utils.get_records()``.
    """
    rg.init(api_url=config.api_url, api_key=config.api_key)
    login(config.hf_token)
    return utils.get_records()
|
37 |
+
|
38 |
+
|
39 |
+
def signup(username, password, records):
    """Handle one click of the "Run" button and return a status string.

    Args:
        username: Username typed into the form.
        password: Password typed into the form.
        records: Pre-loaded records from which the user's next batch is cut.

    Returns:
        A validation or authorization error message, or the success message
        built by ``utils.get_response_message``.
    """
    validation_error = utils.check_inputs(username, password)
    if validation_error:
        return validation_error

    # The account is fetched (or created) first; the token request below
    # then verifies that the supplied password belongs to it.
    account = utils.get_user(username, password)
    auth_response = utils.authorize_user(username, password)
    if auth_response.status_code != 200:
        return "Unable to authorize, please check your credentials."

    user_workspace = utils.add_workspace(account)
    dataset_count = utils.get_num_datasets(account)
    batch = utils.get_records_to_add(account, records, dataset_count)
    new_dataset_name = utils.push_dataset(dataset_count, user_workspace, batch)
    return utils.get_response_message(username, password, new_dataset_name)
|
54 |
+
|
55 |
+
|
56 |
+
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
config.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Runtime settings, all read from environment variables at import time.

# Number of records handed out per dataset; falls back to 250 when unset.
samples_per_group = int(os.getenv("SAMPLES_PER_GROUP", 250))

# Argilla server location and API key (None when the variable is unset).
api_url = os.getenv("ARGILLA_API_URL")
api_key = os.getenv("ARGILLA_API_KEY")

# Hugging Face dataset path and access token (None when the variable is unset).
hf_dataset_path = os.getenv("HF_DATASET_PATH")
hf_token = os.getenv("HF_TOKEN")
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio==3.24.1
|
2 |
+
datasets==2.11.0
|
3 |
+
argilla==1.11.0
|
utils.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import requests
|
4 |
+
import tqdm
|
5 |
+
from datasets import load_dataset
|
6 |
+
import argilla as rg
|
7 |
+
|
8 |
+
import config
|
9 |
+
|
10 |
+
|
11 |
+
def push_dataset(num_datasets, workspace, records_to_add):
    """Create a feedback dataset from the given records and push it to Argilla.

    Args:
        num_datasets: Number of datasets the user already has; used to derive
            the new dataset's name.
        workspace: Workspace object whose ``name`` receives the dataset.
        records_to_add: Records to put into the new dataset.

    Returns:
        The name under which the dataset was pushed.
    """
    name = get_dataset_name(num_datasets)
    feedback_dataset = rg.FeedbackDataset(
        guidelines=_get_guidelines(),
        fields=_get_fields(),
        questions=_get_questions(),
    )
    feedback_dataset.add_records(records_to_add)
    feedback_dataset.push_to_argilla(
        name=name,
        workspace=workspace.name,
        show_progress=True,
    )
    return name
|
21 |
+
|
22 |
+
|
23 |
+
def _get_fields():
    """Return the two text fields shown to annotators for each record."""
    character_description = rg.TextField(name="system", title="Character description")
    history = rg.TextField(name="conversation_history", title="Conversation history")
    return [character_description, history]
|
29 |
+
|
30 |
+
|
31 |
+
def get_records():
    """Convert every row of the source dataset into an Argilla feedback record.

    Each row must provide the keys ``system``, ``conversation_history`` and
    ``external_id``. Progress is reported through ``tqdm``.

    Returns:
        A list of ``rg.FeedbackRecord`` objects, in dataset order.
    """
    feedback_records = []
    for row in tqdm.tqdm(_get_dataset()):
        row_fields = {
            "system": row["system"],
            "conversation_history": row["conversation_history"],
        }
        feedback_records.append(
            rg.FeedbackRecord(fields=row_fields, external_id=row["external_id"])
        )
    return feedback_records
|
44 |
+
|
45 |
+
|
46 |
+
def _get_dataset():
    """Load the ``train`` split of the dataset at ``config.hf_dataset_path``."""
    return load_dataset(config.hf_dataset_path, split="train")
|
49 |
+
|
50 |
+
|
51 |
+
def _get_questions():
    """Return the single required free-text question asked for each record."""
    response_question = rg.TextQuestion(
        name="new-response",
        title="Character response:",
        description=(
            "Write the final version of the Character response, making sure that it matches the character "
            "description and makes sense for the conversation history."
        ),
        required=True,
    )
    return [response_question]
|
62 |
+
|
63 |
+
|
64 |
+
def _get_guidelines():
|
65 |
+
guidelines = None
|
66 |
+
return guidelines
|
67 |
+
|
68 |
+
|
69 |
+
def authorize_user(username, password):
    """Request an access token from the Argilla server for the credentials.

    Args:
        username: Account name to authenticate.
        password: Password for that account.

    Returns:
        The raw ``requests.Response``; callers inspect ``status_code``.

    Raises:
        requests.RequestException: On connection failure or when the server
            does not answer within the timeout.
    """
    credentials = {
        "username": username,
        "password": password,
    }
    # Without a timeout a stalled server would block this handler forever.
    response = requests.post(
        f"{config.api_url}/api/security/token",
        data=credentials,
        timeout=30,
    )
    return response
|
76 |
+
|
77 |
+
|
78 |
+
def get_user(username, password):
    """Return the user named ``username``, creating the account if missing.

    ``password`` is only used when a new account has to be created.
    """
    existing = get_existing_user(username)
    if existing is not None:
        return existing
    return create_new_user(username, password)
|
83 |
+
|
84 |
+
|
85 |
+
def create_new_user(username, password):
    """Create an annotator account and return it.

    The number of users that exist at creation time is stored, as a string,
    in the ``first_name`` field of the new account.
    """
    user_count = len(list(rg.User.list()))
    return rg.User.create(
        username=username,
        first_name=str(user_count),
        last_name="-",
        password=password,
        role="annotator",
    )
|
97 |
+
|
98 |
+
|
99 |
+
def get_existing_user(username):
    """Return the first listed user whose ``username`` matches, else ``None``."""
    matches = (candidate for candidate in rg.User.list() if candidate.username == username)
    return next(matches, None)
|
104 |
+
|
105 |
+
|
106 |
+
def add_workspace(user):
    """Return the workspace named after the user, creating it when possible.

    A freshly created workspace gets the user added to it. A ``ValueError``
    during creation is treated as "workspace already exists" (presumably what
    Argilla raises for a duplicate name; verify against the client version),
    in which case the existing workspace is looked up by name.
    """
    try:
        created = rg.Workspace.create(name=user.username)
        created.add_user(user.id)
    except ValueError:
        print("Workspace for this user already exists.")
        return rg.Workspace.from_name(name=user.username)
    return created
|
114 |
+
|
115 |
+
|
116 |
+
def get_records_to_add(user, records, num_dataset):
    """Select the next batch of records for ``user``.

    The batch position is the integer stored in ``user.first_name`` plus the
    number of datasets the user already has.
    """
    shift_count = int(user.first_name) + num_dataset
    return assign_samples(records, config.samples_per_group, shift_count)
|
121 |
+
|
122 |
+
|
123 |
+
def assign_samples(records, samples_per_group, shifts):
    """Return one window of ``records``, wrapping around at the end.

    The window starts at ``samples_per_group * (shifts - 1)`` modulo the
    number of records, so consecutive ``shifts`` values yield consecutive
    windows.

    Args:
        records: Sliceable sequence to take the window from.
        samples_per_group: Requested window size.
        shifts: 1-based window index; values below 1 wrap around via the
            modulo.

    Returns:
        A slice of ``records`` holding at most ``len(records)`` items. It is
        empty when ``records`` is empty or ``samples_per_group`` is not
        positive.
    """
    total = len(records)
    # Guard the modulo below against division by zero and nonsensical sizes.
    if total == 0 or samples_per_group <= 0:
        return records[:0]

    # Never hand out more than one full pass, otherwise the wrap-around
    # would put duplicate records into the same window.
    count = min(samples_per_group, total)
    start = (samples_per_group * (shifts - 1)) % total
    end = start + count
    if end <= total:
        return records[start:end]
    return records[start:] + records[:end - total]
|
130 |
+
|
131 |
+
|
132 |
+
def get_num_datasets(user):
    """Return how many datasets the Argilla server lists for ``user``.

    The request is authenticated with the user's own API key, so the count
    reflects what that user can see.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    response = requests.get(
        f"{config.api_url}/api/v1/me/datasets",
        headers={"X-Argilla-Api-Key": user.api_key},
        # Without a timeout a stalled server would block this handler forever.
        timeout=30,
    )
    # Surface HTTP errors directly instead of failing later with a KeyError
    # on an error payload that has no "items" key.
    response.raise_for_status()
    return len(response.json()["items"])
|
140 |
+
|
141 |
+
|
142 |
+
def get_dataset_name(num_datasets):
    """Return the name for the user's next dataset (numbering starts at 1)."""
    next_index = num_datasets + 1
    return f"dataset-group-{next_index}"
|
145 |
+
|
146 |
+
|
147 |
+
def check_inputs(username, password):
    """Validate the sign-up form values.

    Args:
        username: Must consist only of ``a-z``, ``0-9``, ``-`` and ``_`` and
            must not start with ``-`` or ``_``.
        password: Must be at least 8 characters long.

    Returns:
        An error message describing the first problem found, or ``None``
        when both values are acceptable.
    """
    # fullmatch is used because ``$`` with re.match also matches just before
    # a trailing newline, which would let "name\n" through.
    if username is None or not re.fullmatch(r"(?!-|_)[a-z0-9-_]+", username):
        return "Your username does not match the pattern '^(?!-|_)[a-z0-9-_]+$', please fix and try again.\n" \
               "Tips:\n" \
               "1. Make it lowercase.\n" \
               "2. Use only english.\n" \
               "3. Use only '_' as special symbol."
    # A missing password is reported as too short instead of raising.
    if password is None or len(password) < 8:
        return "Your password is less than 8 symbols, please fix and try again."
    return None
|
157 |
+
|
158 |
+
|
159 |
+
def get_response_message(username, password, dataset_name):
    """Build the success message shown after a dataset has been pushed.

    The message contains the server URL, the credentials the user entered,
    the new dataset's name and the configured number of samples per group.
    """
    return (
        f"Successfully created/updated your profile at {config.api_url}. "
        f"Use the following credential to login:\n"
        f"Username: {username}\n"
        f"Password: {password}\n\n"
        f"You will find the dataset '{dataset_name}' with {config.samples_per_group} new samples.\n"
        f"Please take your time to annotate the data. If you finished all provided samples, "
        f"simply use the same credentials in this Gradio Space and we will add you another dataset with "
        f"new samples. "
    )
|