"""H2O prompt-formatting wrapper around ``transformers.TextGenerationPipeline``."""

from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

# Prompt template. The user's instruction is substituted into ``{instruction}``;
# the template ends with the ``<|answer|>`` marker, so the model's generation
# continues directly after it.
STYLE = "<|prompt|>{instruction}</s><|answer|>"
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the H2O prompt template.

    Each input string is formatted into ``self.prompt`` (``STYLE``) before
    tokenization, and each generated string is trimmed back down to the
    model's answer after decoding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Template used by ``preprocess``; it must contain an ``{instruction}`` field.
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap ``prompt_text`` in the prompt template, then defer to the base class.

        Args:
            prompt_text: The raw user instruction.
            prefix: Forwarded unchanged to ``TextGenerationPipeline.preprocess``.
            handle_long_generation: Forwarded unchanged to the base class.
            **generate_kwargs: Forwarded unchanged to the base class.

        Returns:
            Whatever the base ``preprocess`` returns for the formatted prompt.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    @staticmethod
    def _extract_answer(text):
        """Return the answer portion of a decoded generation.

        Takes the segment after the first ``<|answer|>`` marker and cuts it at
        the first ``<|prompt|>`` marker, which drops any new turn the model
        started on its own. When ``<|answer|>`` is absent (e.g. the prompt was
        not included in the decoded text), the whole text is treated as the
        answer instead of raising ``IndexError``.
        """
        segments = text.split("<|answer|>")
        answer = segments[1] if len(segments) > 1 else segments[0]
        return answer.strip().split("<|prompt|>")[0].strip()

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
        **postprocess_kwargs,
    ):
        """Decode via the base class, then strip the prompt scaffolding from each record.

        Args:
            model_outputs: Forwarded unchanged to the base ``postprocess``.
            return_type: Forwarded unchanged to the base ``postprocess``.
            clean_up_tokenization_spaces: Forwarded unchanged to the base class.
            **postprocess_kwargs: Any further postprocess parameters the base
                class routes here. They are forwarded so that newer base-class
                versions do not fail with ``TypeError``.

        Returns:
            The base class's list of records, with every string
            ``generated_text`` replaced by the extracted answer. Records that
            have no string ``generated_text`` (e.g. token-id output) are
            returned unchanged.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **postprocess_kwargs,
        )
        for rec in records:
            text = rec.get("generated_text")
            # Only text output carries the prompt markers; leave anything else alone.
            if isinstance(text, str):
                rec["generated_text"] = self._extract_answer(text)
        return records