from fastapi import FastAPI, File, UploadFile, Request
from pydantic import BaseModel
from pathlib import Path
from fastapi import Form
from fastapi.responses import JSONResponse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from fastapi import Depends
# In FastAPI, Depends() is used to declare a dependency
from huggingface_hub import InferenceClient
import numpy as np
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFaceHub
from langchain.document_loaders import TextLoader
import torch
from sentence_transformers.util import semantic_search
import requests
import random
import string
import sys
import timeit
import datetime
import io
import os
from dotenv import load_dotenv

load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
model_id = os.getenv('model_id')
hf_token = os.getenv('hf_token')
repo_id = os.getenv('repo_id')
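# A minimal .env for local runs might look like the following; the model_id and
# repo_id values are illustrative assumptions, not taken from this file:
#
#   HUGGINGFACEHUB_API_TOKEN=hf_...
#   model_id=sentence-transformers/all-MiniLM-L6-v2
#   hf_token=hf_...
#   repo_id=HuggingFaceH4/starchat-beta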

def get_embeddings(input_str_texts):
    # Calls the HF Inference API feature-extraction pipeline; api_url and
    # headers are module-level globals defined further down.
    response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options": {"wait_for_model": True}})
    return response.json()
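# Example (hypothetical values): get_embeddings(["hello world"]) returns one
# embedding vector per input string, e.g. [[0.01, -0.02, ...]].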

def generate_random_string(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))

def remove_context(text):
    if 'Context:' in text:
        end_of_context = text.find('\n\n')
        return text[end_of_context + 2:]
    else:
        return text
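# Example: remove_context("Context: some passage\n\nAnswer: 42") -> "Answer: 42".
# This assumes the context block ends at the first blank line; if no blank line
# is found, find() returns -1 and only the leading character is dropped.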

api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {hf_token}"}

llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs={
        "min_length": 512,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.01,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": 49155,
    },
)

#prompt_template = """
#You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
#Your response should be full and easy to understand.
#"""
#PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
#chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
chain = load_qa_chain(llm=llm, chain_type="stuff")
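# chain_type="stuff" concatenates all input documents into a single prompt,
# which is fine here because only the top-k retrieved chunks are passed in.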

app = FastAPI()

class FileToProcess(BaseModel):
    uploaded_file: UploadFile = File(...)

# The route decorator was missing in the scraped source; "/" is an assumed path
# so the handler is actually registered with the app.
@app.get("/")
async def home():
    return "API Working!"

#async def upload_file(user_question: str, file_to_process: FileToProcess = Depends()):
# The route decorator was missing in the scraped source; the exact path is an
# assumption.
@app.post("/pdf_file_qa_process/")
async def pdf_file_qa_process(user_question: str, request: Request, file_to_process: FileToProcess = Depends()):
| print("API Call Triggered.") | |
| start_0 = timeit.default_timer() | |
| uploaded_file = file_to_process.uploaded_file | |
| print("File received:"+uploaded_file.filename) | |
| user_question = request.query_params.get("user_question") | |
| filename = request.query_params.get("filename") | |
| print("User entered question: "+user_question) | |
| print("User uploaded file: "+filename) | |
| random_string = generate_random_string(20) | |
| file_path = Path.cwd() / random_string | |
| file_path.mkdir(parents=True, exist_ok=True) | |
| file_saved_in_api = file_path / uploaded_file.filename | |
| print(file_saved_in_api) | |
| with open(file_saved_in_api, "wb+") as file_object: | |
| file_object.write(uploaded_file.file.read()) | |
    text_splitter = RecursiveCharacterTextSplitter(
        #separator = "\n",
        chunk_size=500,
        chunk_overlap=100,  # striding over the text
        length_function=len,
    )
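    # With chunk_size=500 and chunk_overlap=100, adjacent chunks share 100
    # characters, so a passage that straddles a chunk boundary still appears
    # intact in at least one chunk.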
    doc_reader = PdfReader(file_saved_in_api)
    raw_text = ''
    for i, page in enumerate(doc_reader.pages):
        text = page.extract_text()
        if text:
            raw_text += text

    temp_texts = text_splitter.split_text(raw_text)
    texts = temp_texts
    initial_embeddings = get_embeddings(temp_texts)
    db_embeddings = torch.FloatTensor(initial_embeddings)
    print(db_embeddings)
    print("db_embeddings created...")
    #question = var_query.query
    question = user_question
    print("API Call Query Received: " + question)
    q_embedding = get_embeddings(question)
    final_q_embedding = torch.FloatTensor(q_embedding)
    print(final_q_embedding)

    print("Semantic Similarity Search Starts...")
    start_1 = timeit.default_timer()
    hits = semantic_search(final_q_embedding, db_embeddings, top_k=5)
    end_1 = timeit.default_timer()
    print("Semantic Similarity Search Ends...")
    print(f'Semantic Similarity Search took: {end_1 - start_1} seconds')
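    # semantic_search returns one result list per query; each entry is a dict
    # with 'corpus_id' (an index into texts) and 'score', sorted by decreasing
    # score.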
    page_contents = []
    for i in range(len(hits[0])):
        page_content = texts[hits[0][i]['corpus_id']]
        page_contents.append(page_content)
    print(page_contents)

    temp_page_contents = str(page_contents)
    final_page_contents = temp_page_contents.replace('\\n', '')

    random_string_2 = generate_random_string(20)
    file_path = random_string_2 + ".txt"
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(final_page_contents)

    loader = TextLoader(file_path, encoding="utf-8")
    loaded_documents = loader.load()
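    # The retrieved chunks are round-tripped through a temporary .txt file only
    # so that TextLoader can wrap them as Document objects for the QA chain.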
| print("*****loaded_documents******") | |
| print(loaded_documents) | |
| print("***********") | |
| print(question) | |
| print("*****question******") | |
| print("LLM Chain Starts...") | |
| start_2 = timeit.default_timer() | |
| temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False) | |
| end_2 = timeit.default_timer() | |
| print("LLM Chain Ends...") | |
| print(f'LLM Chain共耗时: @ {end_2 - start_2}') | |
| print(temp_ai_response) | |
| initial_ai_response=temp_ai_response['output_text'] | |
| print(initial_ai_response) | |
| cleaned_initial_ai_response = remove_context(initial_ai_response) | |
    # Strip chat-template tokens and any follow-up questions or boilerplate the
    # model tends to append after the actual answer.
    #final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
    final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip()
    final_ai_response = final_ai_response.partition('¿Cuáles')[0].strip()
    final_ai_response = final_ai_response.partition('¿Qué es')[0].strip()
    final_ai_response = final_ai_response.partition('<|end|>')[0].strip().replace('\n\n', '\n').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
    new_final_ai_response = final_ai_response.split('Unhelpful Answer:')[0].strip()
    new_final_ai_response = new_final_ai_response.split('Note:')[0].strip()
    new_final_ai_response = new_final_ai_response.split('Please provide feedback on how to improve the chatbot.')[0].strip()
    print(new_final_ai_response)

    end_0 = timeit.default_timer()
    print("API Call Ended.")
    print(f'API Call took: {end_0 - start_0} seconds')

    return {"AIResponse": new_final_ai_response}
    #return JSONResponse({"AIResponse": new_final_ai_response})
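
# A minimal client sketch for the endpoint above; the host/port, the
# "/pdf_file_qa_process/" route, and sample.pdf are assumptions (run the
# server with e.g. `uvicorn app:app`):
#
#   import requests
#
#   with open("sample.pdf", "rb") as f:
#       r = requests.post(
#           "http://127.0.0.1:8000/pdf_file_qa_process/",
#           params={"user_question": "What is this PDF about?", "filename": "sample.pdf"},
#           files={"uploaded_file": ("sample.pdf", f, "application/pdf")},
#       )
#   print(r.json()["AIResponse"])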