AlekseyKorshuk committed on
Commit
b58f9ec
·
1 Parent(s): 7d65544

initial commit

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +57 -0
  3. config.py +7 -0
  4. requirements.txt +3 -0
  5. utils.py +168 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Role Play Crowdsource Signup
3
- emoji: 🌍
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
  title: Role Play Crowdsource Signup
3
+ emoji: 👥
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import partial
3
+
4
+ import argilla as rg
5
+ from huggingface_hub import login
6
+ import gradio as gr
7
+
8
+ import config
9
+ import utils
10
+
11
+
12
def main():
    """Build the sign-up Gradio UI and launch it.

    The record pool is loaded once through ``init()`` and bound to the click
    handler, so every button press reuses the same in-memory records.
    """
    records = init()
    on_click = partial(signup, records=records)

    with gr.Blocks() as app:
        gr.Markdown("# Role-Play Crowdsource\n"
                    "TODO")
        # Components are created in display order: inputs, button, status.
        username_box = gr.Textbox(label="Username", placeholder="alekseykorshuk")
        password_box = gr.Textbox(label="Password", placeholder="12345678")
        run_button = gr.Button("Run")
        status_box = gr.Textbox(label="Status")

        run_button.click(
            fn=on_click,
            inputs=[username_box, password_box],
            outputs=status_box,
        )
    app.launch()
27
+
28
+
29
def init():
    """Connect to Argilla, log in to the Hugging Face Hub and load the records.

    Returns:
        Whatever ``utils.get_records()`` returns (the full record pool).
    """
    rg.init(api_url=config.api_url, api_key=config.api_key)
    login(config.hf_token)
    return utils.get_records()
37
+
38
+
39
def signup(username, password, records):
    """Handle one sign-up request and return the status text for the UI.

    Args:
        username: Username typed into the form.
        password: Password typed into the form.
        records: Record pool from which the user's next batch is taken.

    Returns:
        A validation or authorization error message, or the success message
        naming the dataset that was pushed.
    """
    validation_error = utils.check_inputs(username, password)
    if validation_error:
        return validation_error

    # The user is fetched (or created) first, then the credentials are
    # verified against the token endpoint.
    user = utils.get_user(username, password)
    auth_response = utils.authorize_user(username, password)
    if auth_response.status_code != 200:
        return "Unable to authorize, please check your credentials."

    workspace = utils.add_workspace(user)
    num_datasets = utils.get_num_datasets(user)
    batch = utils.get_records_to_add(user, records, num_datasets)
    dataset_name = utils.push_dataset(num_datasets, workspace, batch)
    return utils.get_response_message(username, password, dataset_name)
54
+
55
+
56
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# Number of records handed out per dataset; 250 when the variable is unset.
samples_per_group = int(os.getenv("SAMPLES_PER_GROUP", 250))

# Argilla server URL and API key.
api_url = os.getenv("ARGILLA_API_URL")
api_key = os.getenv("ARGILLA_API_KEY")

# Hugging Face dataset path and access token.
hf_dataset_path = os.getenv("HF_DATASET_PATH")
hf_token = os.getenv("HF_TOKEN")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==3.24.1
2
+ datasets==2.11.0
3
+ argilla==1.11.0
utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import requests
4
+ import tqdm
5
+ from datasets import load_dataset
6
+ import argilla as rg
7
+
8
+ import config
9
+
10
+
11
def push_dataset(num_datasets, workspace, records_to_add):
    """Create a feedback dataset from the given records and push it to Argilla.

    Args:
        num_datasets: Number of datasets the user already has; the new
            dataset's name is derived from it.
        workspace: Workspace object whose ``name`` is the push target.
        records_to_add: Records to put into the new dataset.

    Returns:
        The name under which the dataset was pushed.
    """
    feedback_dataset = rg.FeedbackDataset(
        guidelines=_get_guidelines(),
        fields=_get_fields(),
        questions=_get_questions(),
    )
    feedback_dataset.add_records(records_to_add)

    name = get_dataset_name(num_datasets)
    feedback_dataset.push_to_argilla(
        name=name,
        workspace=workspace.name,
        show_progress=True,
    )
    return name
21
+
22
+
23
def _get_fields():
    """Return the two text fields shown to annotators for each record."""
    system_field = rg.TextField(name="system", title="Character description")
    history_field = rg.TextField(name="conversation_history", title="Conversation history")
    return [system_field, history_field]
29
+
30
+
31
def get_records():
    """Convert every row of the source dataset into an Argilla feedback record.

    Returns:
        A list of ``rg.FeedbackRecord`` objects, one per dataset row, in
        dataset order.
    """
    records = []
    # tqdm only adds a progress bar; rows are consumed in their original order.
    for row in tqdm.tqdm(_get_dataset()):
        row_fields = {
            "system": row["system"],
            "conversation_history": row["conversation_history"],
        }
        records.append(
            rg.FeedbackRecord(fields=row_fields, external_id=row['external_id'])
        )
    return records
44
+
45
+
46
def _get_dataset():
    """Load the ``train`` split of the dataset named by ``config.hf_dataset_path``."""
    return load_dataset(config.hf_dataset_path, split="train")
49
+
50
+
51
def _get_questions():
    """Return the single required free-text question asked for each record."""
    response_question = rg.TextQuestion(
        name="new-response",
        title="Character response:",
        description="Write the final version of the Character response, making sure that it matches the character "
                    "description and makes sense for the conversation history.",
        required=True,
    )
    return [response_question]
62
+
63
+
64
+ def _get_guidelines():
65
+ guidelines = None
66
+ return guidelines
67
+
68
+
69
def authorize_user(username, password):
    """Request an access token for the given credentials.

    Posts the credentials as form data to ``/api/security/token`` on the
    server at ``config.api_url``.

    Args:
        username: Username to authenticate.
        password: Password to authenticate.

    Returns:
        The ``requests.Response``; the caller inspects ``status_code``.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    credentials = {
        "username": username,
        "password": password,
    }
    # requests waits forever by default; bound the wait so a stalled server
    # cannot hang the sign-up handler indefinitely.
    return requests.post(
        f"{config.api_url}/api/security/token",
        data=credentials,
        timeout=30,
    )
76
+
77
+
78
def get_user(username, password):
    """Return the user with this username, creating the account if needed.

    The password is used only when a new user has to be created.
    """
    existing = get_existing_user(username)
    if existing is not None:
        return existing
    return create_new_user(username, password)
83
+
84
+
85
def create_new_user(username, password):
    """Create an annotator account and return it.

    The number of users existing at creation time is stored, as a string, in
    the new user's ``first_name``; ``get_records_to_add`` later parses it
    back into the user's index.
    """
    user_count = len(list(rg.User.list()))
    return rg.User.create(
        username=username,
        first_name=str(user_count),
        last_name="-",
        password=password,
        role="annotator",
    )
97
+
98
+
99
def get_existing_user(username):
    """Return the first listed user whose username matches, or ``None``."""
    matches = (candidate for candidate in rg.User.list() if candidate.username == username)
    return next(matches, None)
104
+
105
+
106
def add_workspace(user):
    """Return the workspace named after the user, creating it when possible.

    A new workspace is created and the user is added to it. If either step
    raises ``ValueError``, the workspace is looked up by name instead.
    """
    workspace_name = user.username
    try:
        personal_workspace = rg.Workspace.create(name=workspace_name)
        personal_workspace.add_user(user.id)
    except ValueError:
        print("Workspace for this user already exists.")
        personal_workspace = rg.Workspace.from_name(name=workspace_name)
    return personal_workspace
114
+
115
+
116
def get_records_to_add(user, records, num_dataset):
    """Select the next batch of records for this user.

    The batch position is the user's index (parsed from ``user.first_name``)
    plus the number of datasets the user already has.
    """
    shifts = int(user.first_name) + num_dataset
    return assign_samples(records, config.samples_per_group, shifts)
121
+
122
+
123
def assign_samples(records, samples_per_group, shifts):
    """Return one batch of records, treating ``records`` as a circular list.

    The batch starts at ``samples_per_group * (shifts - 1)`` (modulo the
    number of records) and wraps around the end of the list when necessary.

    Args:
        records: Sequence of records to draw from.
        samples_per_group: Requested batch size.
        shifts: 1-based batch number; ``shifts == 1`` starts at index 0.

    Returns:
        A new sequence with at most ``samples_per_group`` records. It is
        empty when ``records`` is empty or the requested size is not
        positive, and it never contains the same record twice.
    """
    total = len(records)
    # Guard the modulo below against an empty pool, and treat a non-positive
    # batch size as "nothing to hand out".
    if total == 0 or samples_per_group <= 0:
        return records[:0]

    start = (samples_per_group * (shifts - 1)) % total
    # Cap the batch at the pool size so that wrapping around can never hand
    # out the same record twice within one batch.
    count = min(samples_per_group, total)
    end = start + count
    if end <= total:
        return records[start:end]
    return records[start:] + records[:end - total]
130
+
131
+
132
def get_num_datasets(user):
    """Return how many datasets are listed for the given user.

    Calls ``/api/v1/me/datasets`` on the server at ``config.api_url``,
    authenticating with the user's own API key.

    Args:
        user: User object exposing an ``api_key`` attribute.

    Returns:
        The length of the ``"items"`` list in the JSON response.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status.
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    headers = {"X-Argilla-Api-Key": user.api_key}
    # Bound the wait so a stalled server cannot hang the sign-up handler.
    response = requests.get(
        f"{config.api_url}/api/v1/me/datasets",
        headers=headers,
        timeout=30,
    )
    # Fail with a clear HTTP error instead of an opaque KeyError on "items"
    # when the server returns an error payload.
    response.raise_for_status()
    return len(response.json()["items"])
140
+
141
+
142
def get_dataset_name(num_datasets):
    """Return the name of the user's next dataset.

    Names are 1-based: a user with ``num_datasets`` existing datasets gets
    ``dataset-group-<num_datasets + 1>``.
    """
    return f"dataset-group-{num_datasets + 1}"
145
+
146
+
147
def check_inputs(username, password):
    """Validate the sign-up form values.

    Args:
        username: Must consist of lowercase ASCII letters, digits, ``-`` and
            ``_`` only, and must not start with ``-`` or ``_``.
        password: Must be at least 8 characters long.

    Returns:
        ``None`` when both values are acceptable, otherwise a message for the
        user describing the first problem found.
    """
    # fullmatch is required here: with re.match, `$` also matches just before
    # a trailing newline, so a username such as "name\n" would be accepted.
    if not re.fullmatch(r"^(?!-|_)[a-z0-9-_]+$", username):
        return "Your username does not match the pattern '^(?!-|_)[a-z0-9-_]+$', please fix and try again.\n" \
               "Tips:\n" \
               "1. Make it lowercase.\n" \
               "2. Use only english.\n" \
               "3. Use only '_' as special symbol."
    if len(password) < 8:
        return "Your password is less than 8 symbols, please fix and try again."
    return None
157
+
158
+
159
def get_response_message(username, password, dataset_name):
    """Build the success message shown after a dataset has been pushed.

    The message repeats the credentials, names the new dataset and states how
    many samples it contains (``config.samples_per_group``).
    """
    return (
        f"Successfully created/updated your profile at {config.api_url}. "
        f"Use the following credential to login:\n"
        f"Username: {username}\n"
        f"Password: {password}\n\n"
        f"You will find the dataset '{dataset_name}' with {config.samples_per_group} new samples.\n"
        f"Please take your time to annotate the data. If you finished all provided samples, "
        f"simply use the same credentials in this Gradio Space and we will add you another dataset with "
        f"new samples. "
    )