	GSK-2909-handle-input-too-long-error (#165)
- fix labels not found; handle input too long (6457b7ab3e357b83e70bad7268307a81d2154bfa)
- Remove leading spaces (b709c217e8e05020af6b864792c2941358e29dfc)
- Remove trailing spaces (6dbdec2b3e7c630a187b0171e19325eb4e2f1126)
- Remove leading spaces (c1a5e7ecbc5505f334a5bf1a458bf1d7ffaa00e3)
Co-authored-by: zcy <[email protected]>
- app_text_classification.py +1 -0
- text_classification.py +12 -0
- text_classification_ui_helpers.py +3 -3
    	
app_text_classification.py  CHANGED

@@ -201,6 +201,7 @@ def get_demo():
         gr.on(
             triggers=[
                 model_id_input.change,
+                model_id_input.input,
                 dataset_id_input.change,
                 dataset_config_input.change,
                 dataset_split_input.change,
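The only change here registers model_id_input.input as an additional trigger, so the same handler also runs on direct user edits of the model id field. Below is a minimal, self-contained sketch of this gr.on wiring pattern (gr.on requires a recent Gradio release); the echo_model_id callback and the component labels are placeholders, not the app's actual handler:

    import gradio as gr

    def echo_model_id(model_id):
        # hypothetical stand-in for the app's real validation callback
        return f"Checking model: {model_id}"

    with gr.Blocks() as demo:
        model_id_input = gr.Textbox(label="Hugging Face model id")
        status = gr.Markdown()
        # one handler bound to several events: .change fires on any value change,
        # .input only on direct user edits
        gr.on(
            triggers=[model_id_input.change, model_id_input.input],
            fn=echo_model_id,
            inputs=[model_id_input],
            outputs=[status],
        )

    if __name__ == "__main__":
        demo.launch()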
    	
text_classification.py  CHANGED

@@ -28,10 +28,14 @@ def get_labels_and_features_from_dataset(ds):
         if len(label_keys) == 0: # no labels found
             # return everything for post processing
             return list(dataset_features.keys()), list(dataset_features.keys()), None
+
+        labels = None
         if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
             if hasattr(dataset_features[label_keys[0]], "feature"):
                 label_feat = dataset_features[label_keys[0]].feature
                 labels = label_feat.names
+            else:
+                labels = ds.unique(label_keys[0])
         else:
             labels = dataset_features[label_keys[0]].names
         return labels, features, label_keys

@@ -83,9 +87,17 @@ def hf_inference_api(model_id, hf_token, payload):
     url = f"{hf_inference_api_endpoint}/models/{model_id}"
     headers = {"Authorization": f"Bearer {hf_token}"}
     response = requests.post(url, headers=headers, json=payload)
+
     if not hasattr(response, "status_code") or response.status_code != 200:
         logger.warning(f"Request to inference API returns {response}")
+
     try:
+        output = response.json()
+        if "error" in output and "Input is too long" in output["error"]:
+            payload.update({"parameters": {"truncation": True, "max_length": 512}})
+            response = requests.post(url, headers=headers, json=payload)
+            if not hasattr(response, "status_code") or response.status_code != 200:
+                logger.warning(f"Request to inference API returns {response}")
         return response.json()
     except Exception:
         return {"error": response.content}
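Two behaviours change in this file. First, when the label column is not a datasets.ClassLabel (so it exposes no .names), get_labels_and_features_from_dataset now falls back to the column's distinct values via ds.unique. A small sketch of that fallback on a toy dataset (the dataset itself is made up for illustration):

    from datasets import ClassLabel, Dataset

    # toy dataset whose label column is a plain string feature, not a ClassLabel
    ds = Dataset.from_dict({"text": ["good", "bad", "fine"], "label": ["pos", "neg", "pos"]})

    feature = ds.features["label"]
    if isinstance(feature, ClassLabel):
        labels = feature.names        # encoded datasets expose label names directly
    else:
        labels = ds.unique("label")   # fallback: the distinct raw values, here ['pos', 'neg']
    print(labels)

Second, hf_inference_api now retries once with truncation parameters when the Inference API answers with an "Input is too long" error. A standalone sketch of that retry pattern, assuming the serverless Inference API endpoint; the model id and token are placeholders, and the 512 max_length simply mirrors the diff:

    import requests

    API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"
    HEADERS = {"Authorization": "Bearer hf_xxx"}  # placeholder token

    def classify(text: str) -> dict:
        payload = {"inputs": text}
        response = requests.post(API_URL, headers=HEADERS, json=payload)
        output = response.json()
        # same check as in the diff: an error dict whose message contains "Input is too long"
        if isinstance(output, dict) and "Input is too long" in str(output.get("error", "")):
            # retry once, asking the backend to truncate the text to the model's window
            payload["parameters"] = {"truncation": True, "max_length": 512}
            response = requests.post(API_URL, headers=HEADERS, json=payload)
            output = response.json()
        return output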
    	
text_classification_ui_helpers.py  CHANGED

@@ -341,8 +341,8 @@ def align_columns_and_show_prediction(
     ):
         return (
             gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
-            gr.update(visible=
-            gr.update(visible=
+            gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
+            gr.update(value=prediction_response, visible=True),
             gr.update(visible=True, open=True),
             gr.update(interactive=(run_inference and inference_token != "")),
             "",

@@ -351,7 +351,7 @@ def align_columns_and_show_prediction(

     return (
         gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
-        gr.update(value=prediction_input, lines=len(prediction_input)//225 + 1, visible=True),
+        gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
         gr.update(value=prediction_response, visible=True),
         gr.update(visible=True, open=False),
         gr.update(interactive=(run_inference and inference_token != "")),
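Both return paths now size the prediction preview the same way: roughly one visible line per 225 characters of input, capped at 5 so very long examples no longer stretch the panel. A minimal sketch of that heuristic; preview_update is a hypothetical helper name, not a function in the app:

    import gradio as gr

    def preview_update(prediction_input: str):
        # one visible line per ~225 characters, never more than 5
        n_lines = min(len(prediction_input) // 225 + 1, 5)
        return gr.update(value=prediction_input, lines=n_lines, visible=True)

    print(min(len("short text") // 225 + 1, 5))   # 1
    print(min(len("x" * 10_000) // 225 + 1, 5))   # 5 (capped)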

