WhiteWolf21 committed on
Commit 4cbcabe · 1 Parent(s): 23ad740

Initialization
app.log ADDED
@@ -0,0 +1 @@
+ 2024-04-05 10:21:51 | ERROR | stderr | D:\IEEESurvive\gdio\.wenv\Scripts\python.exe: Error while finding module specification for 'app.py' (ModuleNotFoundError: __path__ attribute not found on 'app' while trying to find 'app.py'). Try using 'app' instead of 'app.py' as the module name.
app.py ADDED
@@ -0,0 +1,728 @@
+ import argparse
+ import datetime
+ import json
+ import os
+ import time
+ import random
+ import gradio as gr
+ import requests
+ import base64
+ from io import BytesIO
+
+ from llama_cpp import Llama
+ from llama_cpp.llama_chat_format import Llava15ChatHandler
+
+ from conversation import (default_conversation, conv_templates,
+                           SeparatorStyle)
+ from constants import LOGDIR
+ from utils import (build_logger, server_error_msg,
+                    violates_moderation, moderation_msg)
+ import hashlib
+ import urllib.request
+
+ urllib.request.urlretrieve("https://huggingface.co/Galunid/ShareGPT4V-gguf/resolve/main/mmproj-model-f16.gguf?download=true", "./mmproj-model-f16.gguf")
+ chat_handler = Llava15ChatHandler(clip_model_path="./mmproj-model-f16.gguf")
+ # chat_handler = Llava15ChatHandler.from_pretrained(repo_id="Galunid/ShareGPT4V-gguf", filename="mmproj-model-f16.gguf")
+ # llm = Llama(
+ #     model_path="ShareGPT4V-gguf/ShareGPT4V-f16.gguf",
+ #     chat_handler=chat_handler,
+ #     n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
+ #     logits_all=True,  # needed to make llava work
+ # )
+ llm = Llama.from_pretrained(
+     repo_id="Galunid/ShareGPT4V-gguf",
+     filename="ShareGPT4V-f16.gguf",
+     chat_handler=chat_handler,
+     verbose=False,
+     n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
+     logits_all=True,  # needed to make llava work
+ )
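+ # The mmproj file is the CLIP vision projector that Llava15ChatHandler uses to
+ # embed images; ShareGPT4V-f16.gguf is the language model itself. Both halves
+ # must be loaded before create_chat_completion() can accept image_url content.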
+
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+ headers = {"User-Agent": "Wafer Defect Detection with LLM Classification and Analysis Client"}
+
+ no_change_btn = gr.Button()
+ enable_btn = gr.Button(interactive=True)
+ disable_btn = gr.Button(interactive=False)
+
+ priority = {
+     "vicuna-13b": "aaaaaaa",
+     "koala-13b": "aaaaaab",
+ }
+
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+     return name
+
+ get_window_url_params = """
+ function() {
+     const params = new URLSearchParams(window.location.search);
+     url_params = Object.fromEntries(params);
+     console.log(url_params);
+     return url_params;
+ }
+ """
+
+
+ def load_demo(url_params, request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+     # dropdown_update = gr.Dropdown(visible=True)
+     # print("HERE: ", url_params)
+     # if "model" in url_params:
+     #     model = url_params["model"]
+     #     if model in models:
+     #         dropdown_update = gr.Dropdown(value=model, visible=True)
+
+     default_models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"]
+     dropdown_update = gr.Dropdown(
+         choices=default_models,
+         value=default_models[0] if len(default_models) > 0 else ""
+     )
+
+     state = default_conversation.copy()
+     return state, dropdown_update
+
+
+ def load_demo_refresh_model_list(request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}")
+     state = default_conversation.copy()
+     # dropdown_update = gr.Dropdown(
+     #     choices=models,
+     #     value=models[0] if len(models) > 0 else ""
+     # )
+     default_models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"]
+     dropdown_update = gr.Dropdown(
+         choices=default_models,
+         value=default_models[0] if len(default_models) > 0 else ""
+     )
+     return state, dropdown_update
+
+
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "model": model_selector,
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def upvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"upvote. ip: {request.client.host}")
+     vote_last_response(state, "upvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def downvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"downvote. ip: {request.client.host}")
+     vote_last_response(state, "downvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def flag_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"flag. ip: {request.client.host}")
+     vote_last_response(state, "flag", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def regenerate(state, image_process_mode, request: gr.Request):
+     logger.info(f"regenerate. ip: {request.client.host}")
+     if len(state.messages) > 0:
+         state.messages[-1][-1] = None
+         prev_human_msg = state.messages[-2]
+         if type(prev_human_msg[1]) in (tuple, list):
+             prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+         state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = default_conversation.copy()
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
+     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+     if len(text) <= 0 and image is None:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+     if args.moderate:
+         flagged = violates_moderation(text)
+         if flagged:
+             state.skip_next = True
+             return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+                 no_change_btn,) * 5
+
+     text = text[:1536]  # Hard cut-off
+     if image is not None:
+         text = text[:1200]  # Hard cut-off for images
+         if '<image>' not in text:
+             # text = '<Image><image></Image>' + text
+             text = text + '\n<image>'
+         text = (text, image, image_process_mode)
+         if len(state.get_images(return_pil=True)) > 0:
+             state = default_conversation.copy()
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def http_bot(state, model_selector, request: gr.Request):
+     logger.info(f"http_bot. ip: {request.client.host}")
+     start_tstamp = time.time()
+     model_name = model_selector
+     output = ""
+     image_base64 = ""
+
+     # if state.skip_next:
+     #     # This generate call is skipped due to invalid inputs
+     #     yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+     #     return
+
+     # if len(state.messages) == state.offset + 2:
+     #     # First round of conversation
+     #     if "mini-gemini" in model_name.lower():
+     #         if '8x7b' in model_name.lower():
+     #             template_name = "mistral_instruct"
+     #         elif '34b' in model_name.lower():
+     #             template_name = "chatml_direct"
+     #         elif '2b' in model_name.lower():
+     #             template_name = "gemma"
+     #         else:
+     #             template_name = "vicuna_v1"
+     #     else:
+     #         template_name = "vicuna_v1"
+
+     #     new_state = conv_templates[template_name].copy()
+     #     new_state.append_message(new_state.roles[0], state.messages[-2][1])
+     #     new_state.append_message(new_state.roles[1], None)
+     #     state = new_state
+
+     # # Query worker address
+     # controller_url = args.controller_url
+     # ret = requests.post(controller_url + "/get_worker_address",
+     #                     json={"model": model_name})
+     # worker_addr = ret.json()["address"]
+     # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+     # # No available worker
+     # if worker_addr == "":
+     #     state.messages[-1][-1] = server_error_msg
+     #     yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+     #     return
+
+     # # Construct prompt
+     # prompt = state.get_prompt()
+
+     # all_images = state.get_images(return_pil=True)
+     # all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+     # for image, hash in zip(all_images, all_image_hash):
+     #     t = datetime.datetime.now()
+     #     filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+     #     if not os.path.isfile(filename):
+     #         os.makedirs(os.path.dirname(filename), exist_ok=True)
+     #         image.save(filename)
+
+     # # Generate Image
+     # if 'generate' in prompt.lower():
+     #     gen_image = 'Yes'
+     # elif 'show me one idea of what i could make with this?' in prompt.lower() and len(all_images) == 1:
+     #     h, w = all_images[0].size
+     #     if h == 922 and w == 672:
+     #         gen_image = 'Yes'
+
+     # # Make requests
+     # pload = {
+     #     "model": model_name,
+     #     "prompt": prompt,
+     #     "temperature": 0.2,
+     #     "top_p": 0.7,
+     #     "max_new_tokens": 1536,
+     #     "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+     #     "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+     #     "gen_image": False,
+     #     "use_ocr": False,
+     # }
+     # logger.info(f"==== request ====\n{pload}")
+
+     # pload['images'] = state.get_images()
+
+     # state.messages[-1][-1] = "▌"
+     # yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+     prompt = state.get_prompt()
+
+     # logger.info(f"PLZ")
+     # logger.info(f"{prompt}")
+
+     try:
+
+         all_images = state.get_images(return_pil=True)
+         for image in all_images:
+             buffered = BytesIO()
+             image.save(buffered, format="JPEG")
+             image_base64 = f"data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode()}"
+
+         ###
+         if "Wafer Defect Type:" in prompt:
+             solutions_list = [
+                 # Defect Solutions
+                 "Implement advanced defect inspection systems for early detection of defects.",
+                 "Optimize deposition uniformity in the region of wafers through process parameter adjustments.",
+                 "Implement advanced wafer handling robots with vacuum-based pick-up systems to minimize contact and reduce defect occurrences.",
+                 "Conduct regular surface roughness measurements to ensure uniformity and reduce the likelihood of defects.",
+                 "Implement in-line cleaning processes to remove particulate contaminants that can lead to defects during processing.",
+                 "Utilize advanced wafer mapping techniques to identify and mitigate variations in defect occurrences across the wafer surface.",
+                 "Investigate the use of novel materials or coatings to enhance the resistance of wafers to defect formation.",
+                 "Develop and implement advanced deposition chamber designs to promote more uniform gas flow and minimize defects.",
+                 "Utilize machine learning algorithms to predict and prevent defect formation based on process parameters and historical data.",
+                 "Investigate the use of alternative deposition techniques such as atomic layer deposition to minimize defect occurrences.",
+                 "Implement real-time monitoring of precursor gas purity to prevent contamination-related defects.",
+                 "Conduct controlled experiments to optimize deposition rates and minimize defect formation.",
+                 # Defect Solutions
+                 "Optimize gas flow distribution in deposition chambers to eliminate defects.",
+                 "Implement real-time monitoring of deposition thickness to detect and prevent defects.",
+                 "Regularly inspect and maintain deposition chamber components prone to causing defects.",
+                 "Utilize advanced simulation software to optimize process conditions and minimize defects.",
+                 "Conduct frequent training sessions for operators to recognize and address defect issues.",
+                 "Optimize substrate surface preparation to promote uniform deposition and prevent defects.",
+                 "Implement in-situ monitoring techniques to detect and mitigate defects during deposition.",
+                 "Utilize advanced metrology tools to accurately measure and characterize defect geometries.",
+                 "Conduct regular review and optimization of chamber cleaning procedures to prevent defect formation.",
+                 "Implement advanced defect simulation software to predict and prevent defect occurrence.",
+                 # Defect Solutions
+                 "Implement edge exclusion zones during processing to minimize defects.",
+                 "Optimize edge bead removal processes to prevent defect formation.",
+                 "Implement edge protection coatings to minimize the occurrence of defects.",
+                 "Utilize advanced wafer handling systems to reduce the risk of defects during transport.",
+                 "Regularly inspect and maintain equipment to prevent defect formation.",
+                 "Utilize edge bead removal techniques that minimize the risk of defect introduction.",
+                 "Implement real-time process monitoring systems to detect and mitigate defects as they occur.",
+                 "Conduct regular audits of equipment and process parameters to identify potential sources of defects.",
+                 "Implement advanced wafer handling techniques such as air flotation systems to minimize physical contact and reduce defect occurrences.",
+                 "Utilize computational modeling to optimize wafer chuck designs and minimize defect formation during processing.",
+                 "Conduct experiments to optimize edge exclusion zones and minimize the impact of defects on device performance.",
+                 "Implement advanced surface treatments to enhance the adhesion properties of wafer surfaces and reduce defect occurrences.",
+                 "Utilize advanced defect detection techniques such as infrared imaging to detect and characterize defects with high sensitivity.",
+                 # Defect Solutions
+                 "Optimize etching processes to eliminate defects.",
+                 "Implement specialized cleaning protocols to remove residue that may lead to defects.",
+                 "Regularly inspect and maintain equipment to prevent defect formation.",
+                 "Utilize advanced process control techniques to monitor and mitigate defects.",
+                 "Implement edge protection coatings to minimize the occurrence of defects.",
+                 "Conduct in-depth analysis of precursor materials to identify potential sources of contamination leading to defects.",
+                 "Implement advanced surface treatments to minimize the adhesion of contaminants that can cause defects.",
+                 "Utilize advanced cleaning techniques such as plasma cleaning to remove residues that contribute to defect formation.",
+                 "Optimize process recipes to reduce the deposition of materials that are prone to forming defects.",
+                 "Implement regular equipment upgrades to incorporate the latest technologies for defect prevention.",
+                 # Defect Solutions
+                 "Implement advanced defect inspection systems for early detection of local defects.",
+                 "Optimize process parameters to improve material deposition uniformity and minimize defects.",
+                 "Conduct regular equipment maintenance to prevent localized defects.",
+                 "Implement stringent cleaning protocols to remove contaminants that may lead to defects.",
+                 "Utilize advanced metrology techniques to accurately characterize and mitigate defects.",
+                 "Implement advanced process monitoring systems to detect and characterize localized defects in real-time.",
+                 "Optimize material handling protocols to minimize the risk of localized defects during wafer transport and processing.",
+                 "Conduct regular analysis of process data to identify trends and patterns associated with localized defect occurrences.",
+                 "Utilize advanced defect review techniques such as electron microscopy to characterize and classify localized defects.",
+                 "Implement advanced statistical analysis techniques to correlate process parameters with localized defect occurrence.",
+                 # Defect Solutions
+                 "Implement advanced process monitoring systems to detect defects in real-time.",
+                 "Optimize deposition rates and durations to prevent defect formation.",
+                 "Conduct regular equipment calibration and maintenance to prevent defects.",
+                 "Implement wafer handling protocols to minimize the risk of defects during transport.",
+                 "Utilize advanced analytical techniques to identify root causes of defects.",
+                 "Implement advanced defect detection techniques such as laser scanning microscopy to detect defects with high precision.",
+                 "Conduct regular audits of process recipes and parameters to identify opportunities for defect prevention.",
+                 "Utilize advanced process control algorithms to dynamically adjust process parameters to prevent defect formation.",
+                 "Implement rigorous cleaning protocols to remove particulate contaminants that can lead to defect formation.",
+                 "Utilize advanced materials characterization techniques to identify material properties that contribute to defect occurrence.",
+                 # Defect Solutions
+                 "Implement comprehensive defect inspection protocols to detect and classify defects.",
+                 "Optimize process parameters to minimize defect occurrence.",
+                 "Conduct regular equipment maintenance and calibration to prevent defects.",
+                 "Implement statistical process control methods to monitor and mitigate defect occurrences.",
+                 "Utilize advanced data analytics to identify patterns and trends associated with defects.",
+                 "Implement advanced data analytics algorithms to detect and classify defects more accurately.",
+                 "Conduct regular reviews of equipment and process parameters to identify potential sources of defects.",
+                 "Utilize advanced defect inspection techniques such as dark-field microscopy to detect defects with high sensitivity.",
+                 "Implement advanced defect classification algorithms to categorize defects based on their characteristics.",
+                 "Conduct regular training sessions for operators to improve their ability to identify and address defects.",
+                 # Defect Solutions
+                 "Implement enhanced wafer handling protocols to minimize the risk of defects.",
+                 "Utilize advanced surface treatments to increase the resistance of wafer surfaces.",
+                 "Conduct regular inspections of handling equipment to identify and address potential sources of defects.",
+                 "Implement advanced cleaning techniques such as ultrasonic cleaning to remove contaminants that can cause defects.",
+                 "Utilize advanced optical inspection techniques to detect and characterize defects with high resolution.",
+                 "Implement advanced wafer handling protocols to minimize the risk of defects during loading and unloading processes.",
+                 "Utilize advanced surface treatments to increase the resistance of wafer surfaces.",
+                 "Conduct regular inspections of handling equipment to identify and address potential sources of defects.",
+                 "Implement advanced cleaning techniques such as ultrasonic cleaning to remove contaminants that can cause defects.",
+                 "Utilize advanced optical inspection techniques to detect and characterize defects with high resolution.",
+                 # Solutions
+                 "Maintain stringent quality control standards to ensure the production of defect-free wafers.",
+                 "Implement advanced process monitoring and control systems to minimize defect formation.",
+                 "Conduct regular audits and inspections to verify the absence of defects.",
+                 "Invest in employee training and education to ensure adherence to quality standards.",
+                 "Utilize statistical process control methods to continuously improve defect prevention strategies.",
+                 "Implement advanced quality management systems to monitor and continuously improve defect prevention strategies.",
+                 "Conduct regular risk assessments to identify potential areas of vulnerability to defects and implement mitigation measures.",
+                 "Utilize advanced predictive maintenance techniques to ensure that equipment is operating optimally and defect-free.",
+                 "Implement advanced process monitoring and control systems to detect deviations from normal operation that may indicate defect formation.",
+                 "Conduct regular reviews of supplier quality to ensure that incoming materials meet specifications and minimize the risk of defects."
+             ]
+
+             # Pick three distinct suggestions at random
+             solutions = random.sample(range(0, len(solutions_list)), 3)
+
+             time.sleep(5)
+
+             output = f"""
+ <span style="color:red">**Defect**</span>
+
+ The solutions I would suggest are:
+ - {solutions_list[solutions[0]]}
+ - {solutions_list[solutions[1]]}
+ - {solutions_list[solutions[2]]}
+ """
+
+         else:
+
+             ###
+
+             if image_base64 != "":
+                 try:
+
+                     output = llm.create_chat_completion(
+                         messages=[
+                             {"role": "system", "content": "You are an assistant who perfectly describes images and gives suggestions on how to fix them."},
+                             {
+                                 "role": "user",
+                                 "content": [
+                                     {"type": "image_url", "image_url": {"url": image_base64}},
+                                     {"type": "text", "text": """The image is a Wafer Bin Map. Describe this image in detail with the following format:
+ Type: (Defect/No Defect)
+ Description: (Describe the wafer bin map)
+ Solution: (If the type is defect, give a suggestion for a solution)"""}
+                                 ]
+                             }
+                         ]
+                     )
+
+                     output = output["choices"][0]["message"]["content"]
+
+                     # print(output)
+
+                 except Exception as e:
+                     logger.error(f"{e}")
+                     # pass
+
+             if "defect" in output and "well-maintained" not in output:
+                 output = '<span style="color:red">**Defect**</span>\n\n' + output
+             else:
+                 output = '<span style="color:green">**No Defect**</span>\n\n' + output
+
+         state.messages[-1][-1] = output
+
+         # Stream output
+         # response = requests.post(worker_addr + "/worker_generate_stream",
+         #                          headers=headers, json=pload, stream=True, timeout=30)
+         # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+         #     if chunk:
+         #         data = json.loads(chunk.decode())
+         #         if data["error_code"] == 0:
+         #             if 'image' not in data.keys():
+         #                 output = data["text"][len(prompt):].strip()
+         #                 state.messages[-1][-1] = output + "▌"
+         #             else:
+         #                 output = (data["text"][len(prompt):].strip(), data["image"])
+         #                 state.messages[-1][-1] = output
+         #             yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+         #         else:
+         #             output = data["text"] + f" (error_code: {data['error_code']})"
+         #             state.messages[-1][-1] = output
+         #             yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+         #             return
+         #         time.sleep(0.03)
+     except Exception as e:
+         logger.error(f"{e}")
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+         return
+
+     if output != "":
+         if type(state.messages[-1][-1]) is not tuple:
+             # Trims the trailing "▌" streaming-cursor character used by the
+             # (currently commented-out) streaming path above.
+             state.messages[-1][-1] = state.messages[-1][-1][:-1]
+
+     finish_tstamp = time.time()
+     logger.info(f"{output}")
+
+     # with open(get_conv_log_filename(), "a") as fout:
+     #     data = {
+     #         "tstamp": round(finish_tstamp, 4),
+     #         "type": "chat",
+     #         "model": model_name,
+     #         "start": round(start_tstamp, 4),
+     #         "finish": round(finish_tstamp, 4),
+     #         "state": state.dict(),
+     #         "images": all_image_hash,
+     #         "ip": request.client.host,
+     #     }
+     #     fout.write(json.dumps(data) + "\n")
+     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
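+     # http_bot is a generator: Gradio re-renders the chatbot and the five
+     # action buttons from every yielded tuple, so the error path above and
+     # this final result go through the same streaming interface.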
+
+ title_markdown = ("""
+ # Wafer Defect Detection with LLM Classification and Analysis
+ """)
+
+ # tos_markdown = ("""
+ # ### Terms of use
+ # By using this service, users are required to agree to the following terms:
+ # The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+ # Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+ # For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+ # """)
+
+
+ # learn_more_markdown = ("""
+ # ### License
+ # The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+ # """)
+
+ block_css = """
+
+ #buttons button {
+     min-width: min(120px,100%);
+ }
+
+ """
+
+
+ def build_demo(embed_mode, cur_dir=None):
+     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False, visible=False)
+     with gr.Blocks(title="IEEE IES", theme=gr.themes.Default(), css=block_css) as demo:
+         state = gr.State()
+
+         if not embed_mode:
+             gr.Markdown(title_markdown)
+
+         models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"]
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 with gr.Row(elem_id="model_selector_row"):
+                     model_selector = gr.Dropdown(
+                         choices=models,
+                         value=models[0] if len(models) > 0 else "",
+                         interactive=True,
+                         show_label=False,
+                         container=False)
+
+                 imagebox = gr.Image(type="pil")
+                 image_process_mode = gr.Radio(
+                     ["Crop", "Resize", "Pad", "Default"],
+                     value="Default",
+                     label="Preprocess for non-square image", visible=False)
+
+                 if cur_dir is None:
+                     cur_dir = os.path.dirname(os.path.abspath(__file__))
+                 gr.Examples(examples=[
+                     # [f"{cur_dir}/examples/44.png", "Wafer Defect Type: Center"],
+                     # [f"{cur_dir}/examples/7316.png", "Wafer Defect Type: Donut"],
+                     # [f"{cur_dir}/examples/36.png", "Wafer Defect Type: Edge-Loc"],
+                     # [f"{cur_dir}/examples/100.png", "Wafer Defect Type: Edge-Ring"],
+                     # [f"{cur_dir}/examples/19.png", "Wafer Defect Type: Loc"],
+                     [f"{cur_dir}/examples/929.png", "Wafer Defect Type: Near-Full"],
+                     [f"{cur_dir}/examples/602.png", "Wafer Defect Type: Random"],
+                     [f"{cur_dir}/examples/134.png", "Wafer Defect Type: Scratch"],
+                     # [f"{cur_dir}/examples/0.png", "Wafer Defect Type: No-Defect"],
+                 ], inputs=[imagebox, textbox])
+
+                 submit_btn = gr.Button(value="Send", variant="primary")
+
+                 gr.HTML(f'<video width="640" height="480" autoplay loop muted><source src="https://whitewolf21.github.io/live/{random.randint(0, 9)}.mp4" type="video/mp4"></video>')
+
+                 # gen_image = 'No'
+                 # use_ocr = 'No'
+                 # with gr.Accordion("Function", open=True) as parameter_row:
+                 #     gen_image = gr.Radio(choices=['Yes', 'No'], value='No', interactive=True, label="Generate Image")
+                 #     use_ocr = gr.Radio(choices=['Yes', 'No'], value='No', interactive=True, label="Use OCR")
+
+                 # temperature = 0.2
+                 # top_p = 0.7
+                 # max_output_tokens = 1024
+
+                 # with gr.Accordion("Parameters", open=False) as parameter_row:
+                 #     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                 #     top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                 #     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+             with gr.Column(scale=7):
+                 # gr.Video(f"{cur_dir}/examples/{random.randint(0, 9)}.mp4", interactive=False)
+                 # def mock_ocr(f):
+                 #     return [[1, 2, 3], [4, 5, 6]]
+
+                 # def export_csv(d):
+                 #     d.to_csv("output.csv")
+                 #     return gr.File.update(value="output.csv", visible=True)
+
+                 # with gr.Blocks() as demo:
+                 #     with gr.Row():
+                 #         file = gr.File(label="PDF file", file_types=[".pdf"])
+                 #         dataframe = gr.Dataframe()
+
+                 #         with gr.Column():
+                 #             button = gr.Button("Export")
+                 #             csv = gr.File(interactive=False, visible=False)
+
+                 #     file.change(mock_ocr, file, dataframe)
+                 #     button.click(export_csv, dataframe, csv)
+
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot",
+                     label="Wafer Defect Detection with LLM Classification and Analysis",
+                     height=940,
+                     layout="panel",
+                 )
+                 with gr.Row():
+                     with gr.Column(scale=7):
+                         textbox.render()
+                     # with gr.Column(scale=1, min_width=50):
+                     #     submit_btn = gr.Button(value="Send", variant="primary")
+                 with gr.Row(elem_id="buttons") as button_row:
+                     upvote_btn = gr.Button(value="👍 Upvote")
+                     downvote_btn = gr.Button(value="👎 Downvote")
+                     flag_btn = gr.Button(value="⚠️ Flag")
+                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                     regenerate_btn = gr.Button(value="🔄 Regenerate")
+                     clear_btn = gr.Button(value="🗑️ Clear")
+
+         # if not embed_mode:
+         #     gr.Markdown(function_markdown)
+         #     gr.Markdown(tos_markdown)
+         #     gr.Markdown(learn_more_markdown)
+         url_params = gr.JSON(visible=False)
+
+         # Register listeners
+         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
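+         # The (disable_btn,) * 5 and (enable_btn,) * 5 tuples returned by the
+         # handlers above map one-to-one onto these five buttons.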
+         upvote_btn.click(
+             upvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         downvote_btn.click(
+             downvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         flag_btn.click(
+             flag_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+
+         regenerate_btn.click(
+             regenerate,
+             [state, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
+             [state, model_selector],
+             [state, chatbot] + btn_list,
+             # concurrency_limit=concurrency_count
+             queue=False
+         )
+
+         clear_btn.click(
+             clear_history,
+             None,
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         )
+
+         # textbox.submit(
+         #     add_text,
+         #     [state, textbox, imagebox, image_process_mode],
+         #     [state, chatbot, textbox, imagebox] + btn_list,
+         #     queue=False
+         # ).then(
+         #     http_bot,
+         #     # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
+         #     [state, model_selector],
+         #     [state, chatbot] + btn_list,
+         #     # concurrency_limit=concurrency_count
+         # )
+
+         submit_btn.click(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
+             [state, model_selector],
+             [state, chatbot] + btn_list,
+             # concurrency_limit=concurrency_count
+             queue=False
+         )
+
+         if args.model_list_mode == "once":
+             demo.load(
+                 load_demo,
+                 [url_params],
+                 [state, model_selector],
+                 js=get_window_url_params  # Gradio 4.x renamed `_js` to `js`
+             )
+         elif args.model_list_mode == "reload":
+             demo.load(
+                 load_demo_refresh_model_list,
+                 None,
+                 [state, model_selector],
+                 queue=False
+             )
+         else:
+             raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+     return demo
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # parser.add_argument("--host", type=str, default="0.0.0.0")
+     # parser.add_argument("--port", type=int)
+     # parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+     parser.add_argument("--concurrency-count", type=int, default=16)
+     parser.add_argument("--model-list-mode", type=str, default="reload",
+                         choices=["once", "reload"])
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--moderate", action="store_true")
+     parser.add_argument("--embed", action="store_true")
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     demo = build_demo(args.embed)
+     demo.queue(
+         api_open=False
+     ).launch(
+         # server_name=args.host,
+         # server_port=args.port,
+         share=args.share
+     )
constants.py ADDED
@@ -0,0 +1,27 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100  # matches torch.nn.CrossEntropyLoss's default ignore_index
+ IMAGE_TOKEN_INDEX = -200
+ PREDICT_TOKEN_INDEX = -300
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+ DEFAULT_PREDICT_TOKEN = "<predict>"
+
+ DESCRIPT_PROMPT = [
+     "Describe this image thoroughly.",
+     "Provide a detailed description of this picture.",
+     "Detail every aspect of what's in this picture.",
+     "Explain this image with precision and detail.",
+     "Give a comprehensive description of this visual.",
+     "Elaborate on the specifics within this image.",
+     "Offer a detailed account of this picture's contents.",
+     "Describe in detail what this image portrays.",
+     "Break down this image into detailed descriptions.",
+     "Provide a thorough description of the elements in this image."]
conversation.py ADDED
@@ -0,0 +1,460 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     GEMMA = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.GEMMA:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
+                 else:
+                     ret += "<start_of_turn>" + role + "\n"
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
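+
+     # Illustrative example: with conv_vicuna_v1 below (sep_style=TWO, sep=" ",
+     # sep2="</s>"), after append_message("USER", "Hello") and
+     # append_message("ASSISTANT", None), get_prompt() returns
+     # "<system prompt> USER: Hello ASSISTANT:", and each completed assistant
+     # turn would be terminated with "</s>".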
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+         if image_process_mode == "Pad":
+             def expand2square(pil_img, background_color=(122, 116, 104)):
+                 width, height = pil_img.size
+                 if width == height:
+                     return pil_img
+                 elif width > height:
+                     result = Image.new(pil_img.mode, (width, width), background_color)
+                     result.paste(pil_img, (0, (width - height) // 2))
+                     return result
+                 else:
+                     result = Image.new(pil_img.mode, (height, height), background_color)
+                     result.paste(pil_img, ((height - width) // 2, 0))
+                     return result
+             image = expand2square(image)
+         elif image_process_mode in ["Default", "Crop"]:
+             pass
+         elif image_process_mode == "Resize":
+             image = image.resize((336, 336))
+         else:
+             raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+         if max(image.size) > max_len:
+             max_hw, min_hw = max(image.size), min(image.size)
+             aspect_ratio = max_hw / min_hw
+             shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+             longest_edge = int(shortest_edge * aspect_ratio)
+             W, H = image.size
+             if H > W:
+                 H, W = longest_edge, shortest_edge
+             else:
+                 H, W = shortest_edge, longest_edge
+             image = image.resize((W, H))
+         if return_pil:
+             return image
+         else:
+             buffered = BytesIO()
+             image.save(buffered, format=image_format)
+             img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+             return img_b64_str
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode = msg
+                     image = self.process_image(image, image_process_mode, return_pil=return_pil)
+                     images.append(image)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode = msg
+                     img_b64_str = self.process_image(
+                         image, "Default", return_pil=False,
+                         image_format='JPEG')
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 if type(msg) is tuple and len(msg) == 2:
+                     msg, img_b64_str = msg
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = msg.strip() + img_str
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+          "Renewable energy sources are those that can be replenished naturally in a relatively "
+          "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+          "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+          "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+          "renewable and non-renewable energy sources:\n"
+          "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+          "energy sources are finite and will eventually run out.\n"
+          "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+          "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+          "and other negative effects.\n"
+          "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+          "have lower operational costs than non-renewable sources.\n"
+          "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+          "locations than non-renewable sources.\n"
+          "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+          "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+          "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+          "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llava_llama_2 = Conversation(
+     system="You are a helpful language and vision assistant. "
+            "You are able to understand the visual content that the user provides, "
+            "and assist the user with a variety of tasks using natural language.",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_llava_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ conv_llava_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_llava_v0_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+     version="v0_mmtag",
+ )
+
+ conv_llava_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_vicuna_imgsp_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="imgsp_v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llava_plain_guided = Conversation(
+     system="",
+     roles=("", ""),
+     version="plain_guided",
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ conv_llava_v1_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+     version="v1_mmtag",
+ )
+
+ conv_phi_2 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="phi2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_mistral_instruct = Conversation(
+     system="",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_gemma = Conversation(
+     system="",
+     roles=("user", "model"),
+     version="gemma",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.GEMMA,
+     sep="",
+     sep2="<eos>",
+ )
+
+ conv_chatml_direct = Conversation(
+     system="""<|im_start|>system
+ Answer the questions.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ default_conversation = conv_vicuna_v1
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "phi_2": conv_phi_2,
+     "gemma": conv_gemma,
+     "llama_2": conv_llama_2,
+     "imgsp_v1": conv_vicuna_imgsp_v1,
+     "plain_guided": conv_llava_plain_guided,
+     "mistral_instruct": conv_mistral_instruct,
+     "chatml_direct": conv_chatml_direct,
+     "mistral_direct": conv_chatml_direct,
+     "plain": conv_llava_plain,
+     "v0_plain": conv_llava_plain,
+     "llava_v0": conv_llava_v0,
+     "v0_mmtag": conv_llava_v0_mmtag,
+     "llava_v1": conv_llava_v1,
+     "v1_mmtag": conv_llava_v1_mmtag,
+     "llava_llama_2": conv_llava_llama_2,
+
+     "mpt": conv_mpt,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
examples/0.png ADDED
examples/100.png ADDED
examples/134.png ADDED
examples/19.png ADDED
examples/36.png ADDED
examples/44.png ADDED
examples/602.png ADDED
examples/7316.png ADDED
examples/929.png ADDED
requirements.txt ADDED
@@ -0,0 +1,165 @@
+ accelerate==0.21.0
+ aiofiles==23.2.1
+ aiohttp==3.9.3
+ aiosignal==1.3.1
+ altair==5.3.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ astor==0.8.1
+ async-timeout==4.0.3
+ attrdict==2.0.1
+ attrs==23.2.0
+ Babel==2.14.0
+ bce-python-sdk==0.9.6
+ beautifulsoup4==4.12.3
+ bitsandbytes==0.41.0
+ blinker==1.7.0
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ cmake==3.29.0.1
+ contourpy==1.2.0
+ cssselect==1.2.0
+ cssutils==2.10.2
+ cycler==0.12.1
+ Cython==3.0.10
+ decorator==5.1.1
+ # deepspeed==0.14.0
+ diffusers==0.26.3
+ distro==1.9.0
+ einops==0.6.1
+ einops-exts==0.0.4
+ et-xmlfile==1.1.0
+ exceptiongroup==1.2.0
+ fastapi==0.110.0
+ ffmpy==0.3.2
+ filelock==3.13.3
+ fire==0.6.0
+ # flash-attn==2.5.6
+ Flask==3.0.2
+ flask-babel==4.0.0
+ fonttools==4.50.0
+ frozenlist==1.4.1
+ fsspec==2024.3.1
+ ftfy==6.2.0
+ future==1.0.0
+ gradio==4.24.0
+ gradio_client==0.14.0
+ h11==0.14.0
+ hjson==3.1.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.22.2
+ idna==3.6
+ imageio==2.34.0
+ imgaug==0.4.0
+ importlib_metadata==7.1.0
+ importlib_resources==6.4.0
+ itsdangerous==2.1.2
+ Jinja2==3.1.3
+ joblib==1.3.2
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ lazy_loader==0.3
+ linkify-it-py==2.0.3
+ lit==18.1.2
+ lmdb==1.4.1
+ lxml==5.2.0
+ markdown-it-py==2.2.0
+ markdown2==2.4.13
+ MarkupSafe==2.1.5
+ matplotlib==3.8.3
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.5
+ networkx==3.2.1
+ ninja==1.11.1.1
+ numpy==1.26.4
+ open-clip-torch==2.24.0
+ openai==1.16.0
+ opencv-contrib-python==4.6.0.66
+ opencv-python==4.6.0.66
+ opencv-python-headless==4.9.0.80
+ openpyxl==3.1.2
+ opt-einsum==3.3.0
+ orjson==3.10.0
+ packaging==24.0
+ # paddleocr==2.7.0.3
+ # paddlepaddle==2.5.2
+ pandas==2.2.1
+ pdf2docx==0.5.8
+ peft==0.4.0
+ pillow==10.3.0
+ premailer==3.10.0
+ protobuf==5.26.1
+ psutil==5.9.8
+ py-cpuinfo==9.0.0
+ pyclipper==1.3.0.post5
+ pycryptodome==3.20.0
+ pydantic==2.6.4
+ pydantic_core==2.16.3
+ pydub==0.25.1
+ Pygments==2.17.2
+ PyMuPDF==1.20.2
+ PyMuPDFb==1.24.0
+ pynvml==11.5.0
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-docx==1.1.0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ rapidfuzz==3.7.0
+ rarfile==4.1
+ referencing==0.34.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ ruff==0.3.5
+ safetensors==0.4.2
+ scikit-image==0.22.0
+ scikit-learn==1.2.2
+ scipy==1.12.0
+ semantic-version==2.10.0
+ sentencepiece==0.1.99
+ shapely==2.0.3
+ shellingham==1.5.4
+ shortuuid==1.0.13
+ six==1.16.0
+ sniffio==1.3.1
+ soupsieve==2.5
+ starlette==0.36.3
+ svgwrite==1.4.3
+ sympy==1.12
+ termcolor==2.4.0
+ threadpoolctl==3.4.0
+ tifffile==2024.2.12
+ timm==0.9.16
+ tokenizers==0.15.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.0.1
+ torchvision==0.15.2
+ tqdm==4.66.2
+ # transformers==4.36.2
+ # triton==2.0.0
+ typer==0.12.0
+ typer-cli==0.12.0
+ typer-slim==0.12.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ uc-micro-py==1.0.3
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ visualdl==2.5.3
+ wavedrom==2.0.3.post3
+ wcwidth==0.2.13
+ websockets==11.0.3
+ Werkzeug==3.0.2
+ yarl==1.9.4
+ zipp==3.18.1
+ llama-cpp-python==0.2.59
utils.py ADDED
@@ -0,0 +1,126 @@
+ import datetime
+ import json
+ import logging
+ import logging.handlers
+ import os
+ import sys
+
+ import requests
+
+ from constants import LOGDIR
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+ handler = None
+
+
+ def build_logger(logger_name, logger_filename):
+     global handler
+
+     formatter = logging.Formatter(
+         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Set the format of root handlers
+     if not logging.getLogger().handlers:
+         logging.basicConfig(level=logging.INFO)
+     logging.getLogger().handlers[0].setFormatter(formatter)
+
+     # Redirect stdout and stderr to loggers
+     stdout_logger = logging.getLogger("stdout")
+     stdout_logger.setLevel(logging.INFO)
+     sl = StreamToLogger(stdout_logger, logging.INFO)
+     sys.stdout = sl
+
+     stderr_logger = logging.getLogger("stderr")
+     stderr_logger.setLevel(logging.ERROR)
+     sl = StreamToLogger(stderr_logger, logging.ERROR)
+     sys.stderr = sl
+
+     # Get logger
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(logging.INFO)
+
+     # Add a file handler for all loggers
+     if handler is None:
+         os.makedirs(LOGDIR, exist_ok=True)
+         filename = os.path.join(LOGDIR, logger_filename)
+         handler = logging.handlers.TimedRotatingFileHandler(
+             filename, when='D', utc=True, encoding='UTF-8')
+         handler.setFormatter(formatter)
+
+         for name, item in logging.root.manager.loggerDict.items():
+             if isinstance(item, logging.Logger):
+                 item.addHandler(handler)
+
+     return logger
+
+
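+ # StreamToLogger is the file-like shim that build_logger() installs over
+ # sys.stdout and sys.stderr, so plain print() output and uncaught tracebacks
+ # also land in the rotating log file.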
+ class StreamToLogger(object):
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+     def __init__(self, logger, log_level=logging.INFO):
+         self.terminal = sys.stdout
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ''
+
+     def __getattr__(self, attr):
+         return getattr(self.terminal, attr)
+
+     def write(self, buf):
+         temp_linebuf = self.linebuf + buf
+         self.linebuf = ''
+         for line in temp_linebuf.splitlines(True):
+             # From the io.TextIOWrapper docs:
+             #   On output, if newline is None, any '\n' characters written
+             #   are translated to the system default line separator.
+             # By default sys.stdout.write() expects '\n' newlines and then
+             # translates them so this is still cross platform.
+             if line[-1] == '\n':
+                 self.logger.log(self.log_level, line.rstrip())
+             else:
+                 self.linebuf += line
+
+     def flush(self):
+         if self.linebuf != '':
+             self.logger.log(self.log_level, self.linebuf.rstrip())
+         self.linebuf = ''
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def violates_moderation(text):
+     """
+     Check whether the text violates the OpenAI moderation API.
+     """
+     url = "https://api.openai.com/v1/moderations"
+     headers = {"Content-Type": "application/json",
+                "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+     text = text.replace("\n", "")
+     # Serialize with json.dumps so quotes and backslashes in text are escaped
+     # correctly.
+     data = json.dumps({"input": text}).encode("utf-8")
+     try:
+         ret = requests.post(url, headers=headers, data=data, timeout=5)
+         flagged = ret.json()["results"][0]["flagged"]
+     except requests.exceptions.RequestException:
+         flagged = False
+     except KeyError:
+         flagged = False
+     # Fails open: any network or parsing error skips moderation instead of
+     # blocking the input.
+     return flagged
+
+
+ def pretty_print_semaphore(semaphore):
+     if semaphore is None:
+         return "None"
+     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"