Spaces:
Running
Running
Commit
·
fc361e6
1
Parent(s):
e71d7e4
Format text classification and use warning
Browse files- run_jobs.py +1 -1
- text_classification.py +33 -16
- text_classification_ui_helpers.py +9 -9
run_jobs.py
CHANGED
|
@@ -69,7 +69,7 @@ def prepare_env_and_get_command(
|
|
| 69 |
)
|
| 70 |
logger.info(f"Using {executable} as executable")
|
| 71 |
except Exception as e:
|
| 72 |
-
logger.
|
| 73 |
executable = "giskard_scanner"
|
| 74 |
|
| 75 |
command = [
|
|
|
|
| 69 |
)
|
| 70 |
logger.info(f"Using {executable} as executable")
|
| 71 |
except Exception as e:
|
| 72 |
+
logger.warning(f"Create env failed due to {e}, using the current env as fallback.")
|
| 73 |
executable = "giskard_scanner"
|
| 74 |
|
| 75 |
command = [
|
text_classification.py
CHANGED
|
@@ -14,6 +14,7 @@ AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"
|
|
| 14 |
|
| 15 |
logger = logging.getLogger(__file__)
|
| 16 |
|
|
|
|
| 17 |
class HuggingFaceInferenceAPIResponse:
|
| 18 |
def __init__(self, message):
|
| 19 |
self.message = message
|
|
@@ -25,7 +26,7 @@ def get_labels_and_features_from_dataset(ds):
|
|
| 25 |
label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
|
| 26 |
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
| 27 |
|
| 28 |
-
if len(label_keys) == 0:
|
| 29 |
# return everything for post processing
|
| 30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
| 31 |
|
|
@@ -40,11 +41,10 @@ def get_labels_and_features_from_dataset(ds):
|
|
| 40 |
labels = dataset_features[label_keys[0]].names
|
| 41 |
return labels, features, label_keys
|
| 42 |
except Exception as e:
|
| 43 |
-
logging.warning(
|
| 44 |
-
f"Get Labels/Features Failed for dataset: {e}"
|
| 45 |
-
)
|
| 46 |
return None, None, None
|
| 47 |
|
|
|
|
| 48 |
def check_model_task(model_id):
|
| 49 |
# check if model is valid on huggingface
|
| 50 |
try:
|
|
@@ -55,6 +55,7 @@ def check_model_task(model_id):
|
|
| 55 |
except Exception:
|
| 56 |
return None
|
| 57 |
|
|
|
|
| 58 |
def get_model_labels(model_id, example_input):
|
| 59 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
| 60 |
payload = {"inputs": example_input, "options": {"use_cache": True}}
|
|
@@ -63,6 +64,7 @@ def get_model_labels(model_id, example_input):
|
|
| 63 |
return None
|
| 64 |
return extract_from_response(response, "label")
|
| 65 |
|
|
|
|
| 66 |
def extract_from_response(data, key):
|
| 67 |
results = []
|
| 68 |
|
|
@@ -80,6 +82,7 @@ def extract_from_response(data, key):
|
|
| 80 |
|
| 81 |
return results
|
| 82 |
|
|
|
|
| 83 |
def hf_inference_api(model_id, hf_token, payload):
|
| 84 |
hf_inference_api_endpoint = os.environ.get(
|
| 85 |
"HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
|
|
@@ -94,19 +97,26 @@ def hf_inference_api(model_id, hf_token, payload):
|
|
| 94 |
try:
|
| 95 |
output = response.json()
|
| 96 |
if "error" in output and "Input is too long" in output["error"]:
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
return response.json()
|
| 102 |
except Exception:
|
| 103 |
return {"error": response.content}
|
| 104 |
-
|
|
|
|
| 105 |
def preload_hf_inference_api(model_id):
|
| 106 |
-
payload = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
| 108 |
hf_inference_api(model_id, hf_token, payload)
|
| 109 |
|
|
|
|
| 110 |
def check_model_pipeline(model_id):
|
| 111 |
try:
|
| 112 |
task = huggingface_hub.model_info(model_id).pipeline_tag
|
|
@@ -279,6 +289,7 @@ def check_dataset_features_validity(d_id, config, split):
|
|
| 279 |
|
| 280 |
return df, dataset_features
|
| 281 |
|
|
|
|
| 282 |
def select_the_first_string_column(ds):
|
| 283 |
for feature in ds.features.keys():
|
| 284 |
if isinstance(ds[0][feature], str):
|
|
@@ -286,13 +297,17 @@ def select_the_first_string_column(ds):
|
|
| 286 |
return None
|
| 287 |
|
| 288 |
|
| 289 |
-
def get_example_prediction(
|
|
|
|
|
|
|
| 290 |
# get a sample prediction from the model on the dataset
|
| 291 |
prediction_input = None
|
| 292 |
prediction_result = None
|
| 293 |
try:
|
| 294 |
# Use the first item to test prediction
|
| 295 |
-
ds = datasets.load_dataset(
|
|
|
|
|
|
|
| 296 |
if "text" not in ds.features.keys():
|
| 297 |
# Dataset does not have text column
|
| 298 |
prediction_input = ds[0][select_the_first_string_column(ds)]
|
|
@@ -305,10 +320,12 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split,
|
|
| 305 |
if isinstance(results, dict) and "error" in results.keys():
|
| 306 |
if "estimated_time" in results.keys():
|
| 307 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
| 308 |
-
f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
|
|
|
|
| 309 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
| 310 |
-
f"Inference Error: {results['error']}."
|
| 311 |
-
|
|
|
|
| 312 |
while isinstance(results, list):
|
| 313 |
if isinstance(results[0], dict):
|
| 314 |
break
|
|
@@ -402,4 +419,4 @@ def check_hf_token_validity(hf_token):
|
|
| 402 |
response = requests.get(AUTH_CHECK_URL, headers=headers)
|
| 403 |
if response.status_code != 200:
|
| 404 |
return False
|
| 405 |
-
return True
|
|
|
|
| 14 |
|
| 15 |
logger = logging.getLogger(__file__)
|
| 16 |
|
| 17 |
+
|
| 18 |
class HuggingFaceInferenceAPIResponse:
|
| 19 |
def __init__(self, message):
|
| 20 |
self.message = message
|
|
|
|
| 26 |
label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
|
| 27 |
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
| 28 |
|
| 29 |
+
if len(label_keys) == 0: # no labels found
|
| 30 |
# return everything for post processing
|
| 31 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
| 32 |
|
|
|
|
| 41 |
labels = dataset_features[label_keys[0]].names
|
| 42 |
return labels, features, label_keys
|
| 43 |
except Exception as e:
|
| 44 |
+
logging.warning(f"Get Labels/Features Failed for dataset: {e}")
|
|
|
|
|
|
|
| 45 |
return None, None, None
|
| 46 |
|
| 47 |
+
|
| 48 |
def check_model_task(model_id):
|
| 49 |
# check if model is valid on huggingface
|
| 50 |
try:
|
|
|
|
| 55 |
except Exception:
|
| 56 |
return None
|
| 57 |
|
| 58 |
+
|
| 59 |
def get_model_labels(model_id, example_input):
|
| 60 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
| 61 |
payload = {"inputs": example_input, "options": {"use_cache": True}}
|
|
|
|
| 64 |
return None
|
| 65 |
return extract_from_response(response, "label")
|
| 66 |
|
| 67 |
+
|
| 68 |
def extract_from_response(data, key):
|
| 69 |
results = []
|
| 70 |
|
|
|
|
| 82 |
|
| 83 |
return results
|
| 84 |
|
| 85 |
+
|
| 86 |
def hf_inference_api(model_id, hf_token, payload):
|
| 87 |
hf_inference_api_endpoint = os.environ.get(
|
| 88 |
"HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
|
|
|
|
| 97 |
try:
|
| 98 |
output = response.json()
|
| 99 |
if "error" in output and "Input is too long" in output["error"]:
|
| 100 |
+
payload.update({"parameters": {"truncation": True, "max_length": 512}})
|
| 101 |
+
response = requests.post(url, headers=headers, json=payload)
|
| 102 |
+
if not hasattr(response, "status_code") or response.status_code != 200:
|
| 103 |
+
logger.warning(f"Request to inference API returns {response}")
|
| 104 |
return response.json()
|
| 105 |
except Exception:
|
| 106 |
return {"error": response.content}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def preload_hf_inference_api(model_id):
|
| 110 |
+
payload = {
|
| 111 |
+
"inputs": "This is a test",
|
| 112 |
+
"options": {
|
| 113 |
+
"use_cache": True,
|
| 114 |
+
},
|
| 115 |
+
}
|
| 116 |
hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
|
| 117 |
hf_inference_api(model_id, hf_token, payload)
|
| 118 |
|
| 119 |
+
|
| 120 |
def check_model_pipeline(model_id):
|
| 121 |
try:
|
| 122 |
task = huggingface_hub.model_info(model_id).pipeline_tag
|
|
|
|
| 289 |
|
| 290 |
return df, dataset_features
|
| 291 |
|
| 292 |
+
|
| 293 |
def select_the_first_string_column(ds):
|
| 294 |
for feature in ds.features.keys():
|
| 295 |
if isinstance(ds[0][feature], str):
|
|
|
|
| 297 |
return None
|
| 298 |
|
| 299 |
|
| 300 |
+
def get_example_prediction(
|
| 301 |
+
model_id, dataset_id, dataset_config, dataset_split, hf_token
|
| 302 |
+
):
|
| 303 |
# get a sample prediction from the model on the dataset
|
| 304 |
prediction_input = None
|
| 305 |
prediction_result = None
|
| 306 |
try:
|
| 307 |
# Use the first item to test prediction
|
| 308 |
+
ds = datasets.load_dataset(
|
| 309 |
+
dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
|
| 310 |
+
)
|
| 311 |
if "text" not in ds.features.keys():
|
| 312 |
# Dataset does not have text column
|
| 313 |
prediction_input = ds[0][select_the_first_string_column(ds)]
|
|
|
|
| 320 |
if isinstance(results, dict) and "error" in results.keys():
|
| 321 |
if "estimated_time" in results.keys():
|
| 322 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
| 323 |
+
f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
|
| 324 |
+
)
|
| 325 |
return prediction_input, HuggingFaceInferenceAPIResponse(
|
| 326 |
+
f"Inference Error: {results['error']}."
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
while isinstance(results, list):
|
| 330 |
if isinstance(results[0], dict):
|
| 331 |
break
|
|
|
|
| 419 |
response = requests.get(AUTH_CHECK_URL, headers=headers)
|
| 420 |
if response.status_code != 200:
|
| 421 |
return False
|
| 422 |
+
return True
|
text_classification_ui_helpers.py
CHANGED
|
@@ -63,7 +63,7 @@ def get_dataset_splits(dataset_id, dataset_config):
|
|
| 63 |
splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
|
| 64 |
return gr.update(choices=splits, value=splits[0], visible=True)
|
| 65 |
except Exception as e:
|
| 66 |
-
logger.
|
| 67 |
return gr.update(visible=False)
|
| 68 |
|
| 69 |
def check_dataset(dataset_id):
|
|
@@ -83,7 +83,7 @@ def check_dataset(dataset_id):
|
|
| 83 |
""
|
| 84 |
)
|
| 85 |
except Exception as e:
|
| 86 |
-
logger.
|
| 87 |
if "doesn't exist" in str(e):
|
| 88 |
gr.Warning(get_dataset_fetch_error_raw(e))
|
| 89 |
if "forbidden" in str(e).lower(): # GSK-2770
|
|
@@ -232,7 +232,7 @@ def precheck_model_ds_enable_example_btn(
|
|
| 232 |
)
|
| 233 |
except Exception as e:
|
| 234 |
# Config or split wrong
|
| 235 |
-
logger.
|
| 236 |
return (
|
| 237 |
gr.update(interactive=False),
|
| 238 |
gr.update(visible=False),
|
|
@@ -372,30 +372,30 @@ def check_column_mapping_keys_validity(all_mappings):
|
|
| 372 |
|
| 373 |
def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
|
| 374 |
if inference_token == "":
|
| 375 |
-
logger.
|
| 376 |
return gr.update(interactive=False)
|
| 377 |
if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
|
| 378 |
-
logger.
|
| 379 |
return gr.update(interactive=False)
|
| 380 |
|
| 381 |
all_mappings = read_column_mapping(uid)
|
| 382 |
if not check_column_mapping_keys_validity(all_mappings):
|
| 383 |
-
logger.
|
| 384 |
return gr.update(interactive=False)
|
| 385 |
|
| 386 |
if not check_hf_token_validity(inference_token):
|
| 387 |
-
logger.
|
| 388 |
return gr.update(interactive=False)
|
| 389 |
return gr.update(interactive=True)
|
| 390 |
|
| 391 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
|
| 392 |
label_mapping = {}
|
| 393 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
| 394 |
-
logger.
|
| 395 |
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
| 396 |
|
| 397 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
| 398 |
-
logger.
|
| 399 |
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
| 400 |
|
| 401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
|
|
|
| 63 |
splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
|
| 64 |
return gr.update(choices=splits, value=splits[0], visible=True)
|
| 65 |
except Exception as e:
|
| 66 |
+
logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
|
| 67 |
return gr.update(visible=False)
|
| 68 |
|
| 69 |
def check_dataset(dataset_id):
|
|
|
|
| 83 |
""
|
| 84 |
)
|
| 85 |
except Exception as e:
|
| 86 |
+
logger.warning(f"Check your dataset {dataset_id}: {e}")
|
| 87 |
if "doesn't exist" in str(e):
|
| 88 |
gr.Warning(get_dataset_fetch_error_raw(e))
|
| 89 |
if "forbidden" in str(e).lower(): # GSK-2770
|
|
|
|
| 232 |
)
|
| 233 |
except Exception as e:
|
| 234 |
# Config or split wrong
|
| 235 |
+
logger.warning(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
|
| 236 |
return (
|
| 237 |
gr.update(interactive=False),
|
| 238 |
gr.update(visible=False),
|
|
|
|
| 372 |
|
| 373 |
def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
|
| 374 |
if inference_token == "":
|
| 375 |
+
logger.warning("Inference API is not enabled")
|
| 376 |
return gr.update(interactive=False)
|
| 377 |
if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
|
| 378 |
+
logger.warning("Model id or dataset id is not selected")
|
| 379 |
return gr.update(interactive=False)
|
| 380 |
|
| 381 |
all_mappings = read_column_mapping(uid)
|
| 382 |
if not check_column_mapping_keys_validity(all_mappings):
|
| 383 |
+
logger.warning("Column mapping is not valid")
|
| 384 |
return gr.update(interactive=False)
|
| 385 |
|
| 386 |
if not check_hf_token_validity(inference_token):
|
| 387 |
+
logger.warning("HF token is not valid")
|
| 388 |
return gr.update(interactive=False)
|
| 389 |
return gr.update(interactive=True)
|
| 390 |
|
| 391 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
|
| 392 |
label_mapping = {}
|
| 393 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
| 394 |
+
logger.warning(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
| 395 |
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
| 396 |
|
| 397 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
| 398 |
+
logger.warning(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
| 399 |
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
| 400 |
|
| 401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|