import time

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "NousResearch/Llama-2-7b-chat-hf"
# model_name = "facebook/opt-350m"  # smaller model, handy for quick local tests

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Fall back to CPU when no GPU is available, instead of hard-coding 'cuda'.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
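# Note (an assumption about deployment, not part of the original app): a 7B
# model in full precision needs roughly 28 GB of memory, more than a basic
# Space provides. On GPU hardware one would typically load in half precision:
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")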
def predict(message, chatbot, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    system_message = (
        "\nYou are a helpful, respectful and honest assistant. Always answer as "
        "helpfully as possible, while staying safe. Your answers should not "
        "include any harmful, unethical, racist, sexist, toxic, dangerous, or "
        "illegal content. Your answers should be socially unbiased and positive "
        "in nature.\n\nIf a question does not make any sense, or is not factually "
        "coherent, explain why instead of answering something incorrect. If you "
        "don't know the answer to a question, please don't share false information."
    )
    input_system = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
    # Rebuild the conversation in the Llama-2 chat format, accumulating each
    # past (user, assistant) turn on top of the system prompt.
    input_history = input_system
    for user_turn, bot_turn in chatbot:
        input_history += str(user_turn) + " [/INST] " + str(bot_turn) + " </s><s> [INST] "
    input_prompt = input_history + str(message) + " [/INST] "
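    # For illustration (hypothetical one-turn history ["Hi there", "Hello!"]
    # followed by the new message "What is Python?"), input_prompt now reads:
    #   [INST] <<SYS>>
    #   ...system message...
    #   <</SYS>>
    #
    #    Hi there [/INST] Hello! </s><s> [INST] What is Python? [/INST] 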
    inputs = tokenizer.encode(input_prompt, return_tensors="pt").to(device)

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # keep temperature strictly positive for sampling
    top_p = float(top_p)
    generate_kwargs = dict(
        input_ids=inputs,
        do_sample=True,  # temperature and top_p only take effect when sampling
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )
    outputs = model.generate(**generate_kwargs)
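    # The loop further below replays the finished text character by character.
    # A sketch of true incremental streaming with transformers'
    # TextIteratorStreamer (an alternative, not what this app does):
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate, kwargs={**generate_kwargs, "streamer": streamer}).start()
    #   partial = ""
    #   for new_text in streamer:
    #       partial += new_text
    #       yield partial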
    generated_included_full_text = tokenizer.decode(outputs[0])
    print("generated_included_full_text:", generated_included_full_text)

    # Keep only the text after the last [/INST] tag, trimmed at the end-of-sequence token.
    generated_text = generated_included_full_text.split('[/INST] ')[-1]
    if '</s>' in generated_text:
        generated_text = generated_text.split('</s>')[0]
    # Wrap each output line in a small dict, mimicking a token-stream payload.
    tokens = generated_text.split('\n')
    token_list = [{"id": idx + 1, "text": token} for idx, token in enumerate(tokens)]
    response = {"data": {"token": token_list}}

    token_list = response.get('data', {}).get('token', [])
    # Stream the reply one character at a time so the chat UI updates live.
    partial_message = ""
    for token_entry in token_list:
        if token_entry:
            try:
                token_text = token_entry.get('text', None)
                if token_text:
                    for char in token_text:
                        partial_message += char
                        yield partial_message
                        time.sleep(0.01)
                else:
                    gr.Warning(f"The key 'text' does not exist or is None in this token entry: {token_entry}")
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
                continue
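# predict() is a generator, so it can be smoke-tested outside Gradio
# (hypothetical snippet, not part of the app):
#   reply = ""
#   for reply in predict("Hello there!", chatbot=[]):
#       pass
#   print(reply)  # final accumulated answer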
title = "TheBloke/Llama-2-7b-Chat-GPTQ model chatbot"
description = """
This is the TheBloke/Llama-2-7b-Chat-GPTQ model.
"""
css = """.toast-wrap { display: none !important } """

examples = [
    ['Hello there! How are you doing?'],
    ['Can you explain to me briefly what is Python programming language?'],
    ['Explain the plot of Cinderella in a sentence.'],
    ['How many hours does it take a man to eat a Helicopter?'],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]
def vote(data: gr.LikeData):
    if data.liked:
        print("You upvoted this response: " + data.value)
    else:
        print("You downvoted this response: " + data.value)
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
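# gr.ChatInterface forwards the slider values to predict as extra arguments,
# in order: temperature, max_new_tokens, top_p, repetition_penalty.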
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
    cache_examples=False,
    additional_inputs=additional_inputs,
)
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()

demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
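# Note: concurrency_count is the Gradio 3.x queue API; on Gradio 4+ the rough
# equivalent (an assumption, adjust to the installed version) would be
# demo.queue(default_concurrency_limit=75, max_size=100).launch(debug=True).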