|
import json
|
|
import re
|
|
import requests
|
|
|
|
from tclogger import logger
|
|
from constants.models import MODEL_MAP, STOP_SEQUENCES_MAP
|
|
from constants.envs import PROXIES
|
|
from messagers.message_outputer import OpenaiStreamOutputer
|
|
from messagers.token_checker import TokenChecker
|
|
|
|
|
|
class HuggingfaceStreamer:
    """Stream text generations from the Hugging Face Inference API.

    The streamer posts a prompt to
    ``https://api-inference.huggingface.co/models/<model_fullname>`` with
    streaming enabled, then converts the streamed token lines either into a
    single chat-completion dict (``chat_return_dict``) or into a sequence of
    chunks produced by ``OpenaiStreamOutputer`` (``chat_return_generator``).
    """

    # Model used when the requested model is not listed in MODEL_MAP.
    DEFAULT_MODEL = "nous-mixtral-8x7b"

    def __init__(self, model: str):
        """Select the model and prepare the output formatter.

        Args:
            model: Short model name; falls back to ``DEFAULT_MODEL`` when it is
                not a key of ``MODEL_MAP``.
        """
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]
        self.message_outputer = OpenaiStreamOutputer(model=self.model)
        # Always define the attribute so the chat_return_* methods never hit an
        # AttributeError for models that have no stop sequence configured.
        self.stop_sequences = self._lookup_stop_sequences()

    def _lookup_stop_sequences(self):
        """Return the stop sequence configured for the model, or None."""
        if self.model in STOP_SEQUENCES_MAP:
            return STOP_SEQUENCES_MAP[self.model]
        return None

    def _is_stop_token(self, content):
        """Tell whether a streamed token equals the configured stop sequence."""
        stop = self.stop_sequences
        return stop is not None and content.strip() == stop

    @staticmethod
    def _clamp(value, low, high):
        """Restrict ``value`` to the closed interval ``[low, high]``."""
        return min(max(value, low), high)

    @staticmethod
    def _mask_api_key(api_key):
        """Return a loggable form of the key: first 3 and last 4 chars visible.

        Keys of 7 characters or fewer are masked entirely, because the visible
        head and tail would otherwise overlap and reveal the whole key.
        """
        if len(api_key) <= 7:
            return "*" * len(api_key)
        return f"{api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"

    def parse_line(self, line):
        """Extract the token text from one streamed response line.

        Args:
            line: Raw bytes of a non-empty line, optionally starting with a
                ``data:`` prefix followed by a JSON object.

        Returns:
            The string at ``data["token"]["text"]``, or ``""`` when the line is
            not valid JSON or does not carry a textual token (the offending
            payload is logged).
        """
        text = line.decode("utf-8")
        # Anchor the pattern: only the leading field name is protocol framing.
        # An un-anchored substitution would also delete "data:" occurring
        # inside the generated text itself.
        text = re.sub(r"^\s*data:\s*", "", text)

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            logger.err(text)
            return ""

        try:
            content = data["token"]["text"]
        except (KeyError, TypeError):
            logger.err(data)
            return ""

        # Callers invoke str methods on the result, so never hand back a
        # non-string (e.g. a JSON null).
        if not isinstance(content, str):
            logger.err(data)
            return ""
        return content

    def chat_response(
        self,
        prompt: str = None,
        temperature: float = 0.5,
        top_p: float = 0.95,
        max_new_tokens: int = None,
        api_key: str = None,
        use_cache: bool = False,
    ):
        """Send the prompt to the inference endpoint and return the response.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            temperature: Sampling temperature; ``None`` is treated as 0 and the
                value is clamped into ``[0.01, 0.99]``.
            top_p: Nucleus sampling value; ``None`` falls back to 0.95 and the
                value is clamped into ``[0.01, 0.99]``.
            max_new_tokens: Requested generation length. ``None`` or a
                non-positive value means "use everything
                ``TokenChecker.get_token_redundancy()`` allows"; larger values
                are capped to that amount.
            api_key: Optional token sent as a ``Bearer`` authorization header.
            use_cache: Forwarded as ``options.use_cache``.

        Returns:
            The ``requests.Response`` opened with ``stream=True``. The status
            code is logged but a non-200 response is still returned.
        """
        self.request_url = (
            f"https://api-inference.huggingface.co/models/{self.model_fullname}"
        )
        self.request_headers = {
            "Content-Type": "application/json",
        }

        if api_key:
            logger.note(f"Using API Key: {self._mask_api_key(api_key)}")
            self.request_headers["Authorization"] = f"Bearer {api_key}"

        if temperature is None:
            temperature = 0.0
        if top_p is None:
            # Same value as the signature default.
            top_p = 0.95
        temperature = self._clamp(temperature, 0.01, 0.99)
        top_p = self._clamp(top_p, 0.01, 0.99)

        checker = TokenChecker(input_str=prompt, model=self.model)
        token_budget = checker.get_token_redundancy()
        if max_new_tokens is None or max_new_tokens <= 0:
            max_new_tokens = token_budget
        else:
            max_new_tokens = min(max_new_tokens, token_budget)

        self.request_body = {
            "inputs": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
                "return_full_text": False,
            },
            "options": {
                "use_cache": use_cache,
            },
            "stream": True,
        }

        self.stop_sequences = self._lookup_stop_sequences()

        logger.back(self.request_url)
        stream_response = requests.post(
            self.request_url,
            headers=self.request_headers,
            json=self.request_body,
            proxies=PROXIES,
            stream=True,
        )
        status_code = stream_response.status_code
        if status_code == 200:
            logger.success(status_code)
        else:
            logger.err(status_code)

        return stream_response

    def chat_return_dict(self, stream_response):
        """Consume the whole stream and return one chat-completion dict.

        Args:
            stream_response: Object exposing ``iter_lines()`` that yields the
                raw streamed lines (as returned by ``chat_response``).

        Returns:
            A shallow copy of ``message_outputer.default_data`` whose
            ``choices[0].message.content`` holds the concatenated token text,
            with the stop sequence removed and surrounding whitespace stripped.
        """
        final_output = self.message_outputer.default_data.copy()
        final_output["choices"] = [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "",
                },
            }
        ]
        logger.back(final_output)

        pieces = []
        for line in stream_response.iter_lines():
            if not line:
                continue
            content = self.parse_line(line)
            if self._is_stop_token(content):
                logger.success("\n[Finished]")
                break
            logger.back(content, end="")
            pieces.append(content)

        final_content = "".join(pieces)
        # The stop sequence may also arrive glued to other text in one token.
        if self.stop_sequences:
            final_content = final_content.replace(self.stop_sequences, "")

        final_output["choices"][0]["message"]["content"] = final_content.strip()
        return final_output

    def chat_return_generator(self, stream_response):
        """Yield formatted chunks for each streamed token.

        Every non-empty line is parsed and passed to
        ``message_outputer.output`` with content type ``"Completions"``; the
        first token has its leading whitespace removed. When the stop sequence
        is seen, a ``"Finished"`` chunk is yielded and iteration ends. If the
        stream ends without a stop token, a final empty ``"Finished"`` chunk is
        yielded instead.

        Args:
            stream_response: Object exposing ``iter_lines()`` that yields the
                raw streamed lines (as returned by ``chat_response``).
        """
        line_count = 0
        for line in stream_response.iter_lines():
            if not line:
                continue
            line_count += 1

            content = self.parse_line(line)

            if self._is_stop_token(content):
                logger.success("\n[Finished]")
                yield self.message_outputer.output(
                    content=content, content_type="Finished"
                )
                # Nothing may follow the finishing chunk, matching the
                # behavior of chat_return_dict which also stops here.
                return

            if line_count == 1:
                content = content.lstrip()
            logger.back(content, end="")
            yield self.message_outputer.output(
                content=content, content_type="Completions"
            )

        yield self.message_outputer.output(content="", content_type="Finished")
|
|
|