from transformers import Qwen2Config, Qwen2ForCausalLM
import torch
import requests
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
import logging
import re

logging.basicConfig(level=logging.INFO)


class CustomQwen2Config(Qwen2Config):
    model_type = "custom_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        config = super().from_dict(config_dict, **kwargs)
        return config

    def to_dict(self):
        output = super().to_dict()
        output["model_type"] = self.model_type
        return output


class CustomQwen2Model(Qwen2ForCausalLM):
    config_class = CustomQwen2Config

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = None
        self.embedding_model = None
        self.max_iterations = 5
        self.use_search = True
        self.top_k = 3
        self.max_search_attempts = 3

    def set_tokenizer(self, tokenizer=None):
        self.tokenizer = tokenizer

    def set_max_iterations(self, max_iterations):
        self.max_iterations = max_iterations

    def set_use_search(self, use_search):
        self.use_search = use_search

    def set_top_k(self, top_k):
        self.top_k = top_k

    def generate_step(self, input_ids, max_new_tokens=150):
        """
        Generates output from input_ids and returns the output token IDs.
        """
        input_ids = input_ids.to(self.device)
        output_ids = super().generate(input_ids, max_new_tokens=max_new_tokens)
        return output_ids

    def extract_response(self, output_ids, keyword):
        """
        Extracts the text that follows a specific keyword in the generated response.
        Returns the extracted text, or the full decoded output prefixed with "[ALL]"
        if the keyword is not found.
        """
        raw_response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        pattern = rf"{re.escape(keyword)}\s*(.*)"
        match = re.search(pattern, raw_response, re.DOTALL)

        if match:
            extracted_text = match.group(1).strip()
            return extracted_text
        else:
            return "[ALL]" + raw_response

    def generate(self, input_ids, max_new_tokens=150, **kwargs):
        logging.info(f"Maximum keyword regeneration attempts: {self.max_iterations}")
        logging.info(f"External URL reference: {'Enabled' if self.use_search else 'Disabled'}")
        logging.info(f"top_k value: {self.top_k}")

        org_instruction = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        keyword_attempt = 0
        sufficient_info = False
        summarized_info = ""

        while keyword_attempt < self.max_iterations and not sufficient_info:
            logging.info(f"Keyword regeneration attempt: {keyword_attempt + 1}/{self.max_iterations}")

            if self.use_search:
                logging.info("Retrieving relevant information using external URL references...")
                for search_attempt in range(1, self.max_search_attempts + 1):
                    logging.info(f"Search attempt: {search_attempt}/{self.max_search_attempts}")
                    relevant_docs = self.retrieve_relevant_information(org_instruction, top_k=self.top_k)
                    summarized_info = self.summarize_documents(relevant_docs, org_instruction)

                    sufficient_info = self.is_answer_sufficient(summarized_info, org_instruction)
                    if sufficient_info:
                        logging.info("Sufficient information found.")
                        break
                    else:
                        logging.info("Insufficient information. Attempting next search.")

                if not sufficient_info:
                    new_keywords = self.generate_new_keywords(org_instruction)
                    if new_keywords:
                        org_instruction = self.update_instruction_with_new_keywords(org_instruction, new_keywords)
                        logging.info(f"Retrying search with new keywords: {new_keywords}")
                    else:
                        logging.warning("Failed to generate new keywords.")
                        break
            else:
                # Search is disabled, so no external information can be gathered.
                summarized_info = ""
                sufficient_info = False
                break

            keyword_attempt += 1

        if not sufficient_info:
            logging.info("Relevant data sources not found. Performing self-reasoning.")
            final_response = self.self_reasoning(org_instruction, max_new_tokens)
        else:
            final_response = self.generate_answer(org_instruction, summarized_info, max_new_tokens)

        final_response_ids = self.tokenizer.encode(final_response, return_tensors="pt").to(self.device)
        return final_response_ids

    def retrieve_relevant_information(self, user_input, top_k=3):
        search_query = self.generate_search_query(user_input)
        logging.info(f"Generated search query: {search_query}")

        if not search_query:
            logging.warning("Search query is empty.")
            return ["No relevant information found."]

        with DDGS() as ddgs:
            search_results = ddgs.text(
                keywords=search_query,
                region='wt-wt',
                safesearch='off',
                timelimit=None,
                max_results=20
            )
            search_results = list(search_results)

        if not search_results:
            return ["No relevant information found."]

        documents = []
        for result in search_results:
            if 'body' in result and result['body']:
                documents.append(result['body'])
            elif 'snippet' in result and result['snippet']:
                documents.append(result['snippet'])

        documents = documents[:top_k]
        return documents

    def generate_search_query(self, user_input):
        """
        Generates a search query using the model's own inference.
        """
        prompt = f"""
User's question:
{user_input}

Organize what you need to know to answer this problem and list three keywords to research.

Keywords:
-"""

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=50)
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        pattern = r"Keywords:\s*(.*)"
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            keywords_text = match.group(1).strip()
            keywords = re.findall(r"-\s*(.*)", keywords_text)
            search_query = ' '.join(keywords)
            logging.info(f"Generated search query: {search_query}")
            return search_query
        else:
            logging.warning("Failed to generate keywords.")
            return ""

    def generate_new_keywords(self, user_input):
        """
        Attempts to regenerate keywords.
        """
        prompt = f"""
User's question:
{user_input}

Insufficient information was obtained. Please generate new keywords.
List three new keywords.

Keywords:
-"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=50)
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        pattern = r"Keywords:\s*(.*)"
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            keywords_text = match.group(1).strip()
            keywords = re.findall(r"-\s*(.*)", keywords_text)
            search_query = ' '.join(keywords)
            logging.info(f"Regenerated search query: {search_query}")
            return search_query
        else:
            logging.warning("Failed to extract regenerated keywords.")
            return ""

    def update_instruction_with_new_keywords(self, instruction, new_keywords):
        """
        Incorporates new keywords into the original instruction.
        """
        updated_instruction = f"{instruction} Keywords: {new_keywords}"
        return updated_instruction

    def is_answer_sufficient(self, summarized_info, user_input):
        """
        Determines if the summarized information is sufficient to answer the question.
        """
        prompt = f"""
User's question:
{user_input}

Retrieved information:
{summarized_info}

Based on this information, determine if you can answer the user's question.
If yes, respond with "Yes". If no, respond with "No" only.
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=10)
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

        if "Yes" in generated_text:
            return True
        else:
            return False

    def generate_answer(self, user_input, summarized_info, max_new_tokens=150):
        """
        Generates an answer based on the retrieved information.
        """
        step1_prompt = f"""
#User's question:
{user_input}

#Step 1: Understanding the question and extracting key points
Accurately understand the user's question or instructions.
Output the rules for answering and the tasks to be performed in a bullet list.

#Rules for answering and tasks to be performed:
"""
        step1_input_ids = self.tokenizer.encode(step1_prompt, return_tensors="pt").to(self.device)
        outputs_step1 = self.generate_step(step1_input_ids, max_new_tokens=max_new_tokens)
        step1_response = self.extract_response(outputs_step1, "#Rules for answering and tasks to be performed:")
        logging.info("Understanding the question...\n======================\n" + step1_response)

        step2_prompt = f"""
#Step 2: Considerations for problem-solving
Based on the content of Step 1, consider approaches and necessary information for solving the problem.

#Step 2 response:
"""
        step2_input_ids = self.tokenizer.encode(step1_response + step2_prompt, return_tensors="pt").to(self.device)
        outputs_step2 = self.generate_step(step2_input_ids, max_new_tokens=max_new_tokens)
        step2_response = self.extract_response(outputs_step2, "#Step 2 response:")
        logging.info("Considering approaches for problem-solving...\n======================\n" + step2_response)

        step3_prompt = f"""
#Step 3: Creating the initial answer
Based on the content so far, create an initial answer to the user's question.
Your information may not be up-to-date. Fully consider information from the internet.

#Latest internet information:
{summarized_info}

#Initial answer:
"""
        step3_input_ids = self.tokenizer.encode(step2_response + step3_prompt, return_tensors="pt").to(self.device)
        outputs_step3 = self.generate_step(step3_input_ids, max_new_tokens=max_new_tokens)
        step3_response = self.extract_response(outputs_step3, "#Initial answer:")
        logging.info("Creating the initial answer...\n======================\n" + step3_response)

        reflection_prompt = f"""
#Step 4: Reflection (Self-verification)
Verify whether the initial answer accurately responds to the user's question or instructions, and point out any errors or areas for improvement.
Be cautious of overinterpreting the instructions and critically assess whether you have accurately understood them.
Your information may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question and provide an accurate answer to the question itself. (Ensure that you provide an answer to the question itself)

#User's question:
{user_input}

#Latest internet information:
{summarized_info}

#Initial answer:
{step3_response}

#Reflection result:
"""
        reflection_input_ids = self.tokenizer.encode(reflection_prompt, return_tensors="pt").to(self.device)
        outputs_reflection = self.generate_step(reflection_input_ids, max_new_tokens=max_new_tokens)
        reflection_response = self.extract_response(outputs_reflection, "#Reflection result:")
        logging.info("Performing reflection...\n======================\n" + reflection_response)

        final_prompt = f"""
#Step 5: Creating the final answer
Based on the reflection results, modify the initial answer as needed.
Your knowledge may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question, and check for overinterpretation, misunderstandings, omissions, and careless mistakes.
Create the final answer incorporating these.

#Initial answer:
{step3_response}

#Reflection result:
{reflection_response}

#Latest internet information:
{summarized_info}

#User's question:
{user_input}

Please provide the final answer to the user's question.
#Final answer:
"""
        final_input_ids = self.tokenizer.encode(final_prompt, return_tensors="pt").to(self.device)
        outputs_final = self.generate_step(final_input_ids, max_new_tokens=max_new_tokens)
        final_response = self.extract_response(outputs_final, "#Final answer:").strip()

        return final_response

    def self_reasoning(self, user_input, max_new_tokens=150):
        """
        Generates an answer based on self-reasoning.
        """
        prompt = f"""
User's question:
{user_input}

No relevant information was found on the internet. Please use your own knowledge and reasoning to answer.

#Answer based on self-reasoning:
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=max_new_tokens)
        generated_text = self.extract_response(output_ids, "#Answer based on self-reasoning:").strip()
        logging.info("Answer based on self-reasoning:\n======================\n" + generated_text)
        return generated_text

    def process_document(self, doc, user_input):
        """
        Determines if each document is relevant to the user's question and generates an answer if applicable.
        """
        truncated_doc = doc[:2000]  # Truncate the document if it is too long.
        prompt = f"""
User's question:
{user_input}

Content of the document:
{truncated_doc}

Do not think of the question superficially. Use paradoxes and rephrasing to organize your thinking.
Create an answer to the question based on the content of this document.
Understand the points of disagreement between your own thoughts and the answer you would create based on this document, and prioritize the answer based on the document.

Answer:
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=500)
        generated_text = self.extract_response(output_ids, "Answer:")
        logging.info("Document processing result: " + generated_text)

        if "low relevance" in generated_text:
            return ""
        else:
            return generated_text.strip()

    def summarize_documents(self, documents, user_input):
        """
        Processes each document and summarizes relevant information.
        """
        summaries = []
        for doc in documents:
            processed_text = self.process_document(doc, user_input)
            if processed_text:
                summaries.append(processed_text)
        return "\n\n".join(summaries)
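

# Example usage: a minimal sketch of wiring the class up end to end. The checkpoint
# name below is a placeholder assumption (any Qwen2 checkpoint should work, since the
# custom class only layers search-and-reasoning logic on top of Qwen2ForCausalLM), and
# registering with the Auto* API is optional but lets model_type "custom_qwen2" resolve.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    AutoConfig.register("custom_qwen2", CustomQwen2Config)
    AutoModelForCausalLM.register(CustomQwen2Config, CustomQwen2Model)

    base_model = "Qwen/Qwen2-0.5B-Instruct"  # placeholder checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = CustomQwen2Model.from_pretrained(base_model)
    model.set_tokenizer(tokenizer)
    model.set_max_iterations(3)
    model.set_use_search(True)
    model.set_top_k(3)

    question = "What is the latest stable release of Python?"
    input_ids = tokenizer.encode(question, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=300)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))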