|
import spaces |
|
import bm25s |
|
from bm25s.hf import BM25HF, TokenizerHF |
|
import gradio as gr |
|
import json |
|
import Stemmer |
|
import time |
|
import torch |
|
import os |
|
from transformers import AutoTokenizer, AutoModel, pipeline , AutoModelForSequenceClassification, AutoModelForCausalLM |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
from datasets import concatenate_datasets, load_dataset, load_from_disk |
|
from huggingface_hub import hf_hub_download |
|
from contextual import ContextualAI |
|
from openai import AzureOpenAI |
|
from datetime import datetime |
|
import sys |
|
from datetime import datetime |
|
from pathlib import Path |
|
from uuid import uuid4 |
|
import pickle |
|
from huggingface_hub import CommitScheduler |
|
from ast import literal_eval |
|
import re |
|
import requests |
|
|
|
|
|
|
|
def run_courtlistener_api(casename, citation, court):
    '''
    Searches CourtListener for a case name and builds a markdown link to the first hit.
    :param casename: case name sent as the free-text search query.
    :param citation: citation string of the opinion (accepted but not used by the lookup).
    :param court: short court name (accepted but not used by the lookup).
    :return: markdown hyperlink to the first search result, or -1 when the request
             fails, the response cannot be read, or the search has no results.
    '''
    url = "https://www.courtlistener.com/api/rest/v4/search/"
    try:
        # Without a timeout a stalled connection would block the whole chat response.
        response = requests.get(url, params={"q": casename}, timeout=10)
    except requests.RequestException as e:
        print(e)
        return -1

    if response.status_code != 200:
        return -1

    try:
        results = response.json()["results"]
    except (ValueError, KeyError, TypeError) as e:
        # Body was not JSON, or did not have the expected {"results": [...]} shape.
        print(e)
        return -1
    print(results)

    # A 200 response with zero hits is valid; indexing [0] would raise IndexError.
    if not results:
        return -1

    try:
        new_url = "https://www.courtlistener.com" + results[0]["absolute_url"]
    except (KeyError, TypeError) as e:
        print(e)
        return -1
    return f"[Click to see opinion on CourtListener]({new_url})"
|
|
|
|
|
|
|
# Local folder that is handed to CommitScheduler as `folder_path`.
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# The log file name gets a fresh uuid every time this module is loaded.
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"

scheduler = CommitScheduler(
    repo_id="ai-law-society-lab/federal-queries-save-dataset",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    token=os.getenv("hf_token"),
)

# Connection settings read by text_prompt_call when it builds the AzureOpenAI client.
sandbox_api_key = os.getenv("AI_SANDBOX_KEY")
sandbox_endpoint = "https://api-ai-sandbox.princeton.edu/"
sandbox_api_version = "2024-02-01"
|
|
|
def text_prompt_call(model_to_be_used, system_prompt, user_prompt ):
    '''
    Sends one system/user message pair to the Azure OpenAI chat completions API.
    :param model_to_be_used: value passed as `model` to the completion request.
    :param system_prompt: content of the system message.
    :param user_prompt: content of the user message.
    :return: message content of the first choice in the response.
    '''
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # A new client is built on every call from the module-level sandbox settings.
    client_gpt = AzureOpenAI(
        api_key=sandbox_api_key,
        azure_endpoint=sandbox_endpoint,
        api_version=sandbox_api_version,
    )
    completion = client_gpt.chat.completions.create(
        model=model_to_be_used,
        temperature=0.0,
        max_tokens=1000,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
|
|
def format_metadata_as_str(metadata):
    '''
    Builds the "case name, court, date filed" header line for an opinion.
    :param metadata: mapping with string values under "case_name",
                     "court_short_name" and "date_filed".
    :return: the three values joined by ", ", or "" when a key is missing,
             a value is not a string, or metadata is not a mapping.
    '''
    try:
        return ", ".join(
            [metadata["case_name"], metadata["court_short_name"], metadata["date_filed"]]
        )
    except (KeyError, TypeError, IndexError):
        # Only incomplete or malformed metadata is treated as "no header";
        # unrelated errors are no longer silently swallowed.
        return ""
|
|
|
def show_user_query(user_message, history):
    '''
    Moves the submitted query out of the textbox and into the chatbot.
    :param user_message: query text entered by the user.
    :param history: list of [user, bot] message pairs shown so far.
    :return: "" (to clear the textbox) and a new history list with
             [user_message, None] appended; the input list is not modified.
    '''
    pending_turn = [user_message, None]
    updated_history = history + [pending_turn]
    return "", updated_history
|
|
|
|
|
def format_metadata_for_reranking(metadata, text, idx):
    '''
    Serializes one search result as tagged lines for the reranking prompt.
    :param metadata: dict that may provide "court_short_name", "date_filed" and
                     "citation_count". Anything that is not a dict (for example the
                     "" placeholder run_retrieval stores when an opinion has no
                     metadata) is treated as having no fields.
    :param text: paragraph text; runs of whitespace are collapsed to single spaces.
    :param idx: identifier written into the <id> tag.
    :return: newline-separated <id>, <court>, <year>, <citation count> and
             <paragraph> tags, terminated by a newline. Missing fields are
             rendered as "unknown".
    '''
    # (metadata key, tag name used in the prompt)
    tagged_fields = (
        ("court_short_name", "court"),
        ("date_filed", "year"),
        ("citation_count", "citation count"),
    )
    if not isinstance(metadata, dict):
        metadata = {}

    lines = [f"<id>{idx!s}</id>"]
    for key, tag in tagged_fields:
        lines.append(f"<{tag}>{metadata.get(key, 'unknown')!s}</{tag}>")
    lines.append("<paragraph>" + " ".join(text.split()) + "</paragraph>")
    return "\n".join(lines) + "\n"
|
|
|
|
|
def run_extractive_qa(query, contexts):
    '''
    Asks the same question against every context with the extractive QA pipeline.
    :param query: question text.
    :param contexts: iterable of passage strings.
    :return: whatever the module-level `extractive_qa` pipeline returns for the batch.
    '''
    qa_inputs = []
    for context in contexts:
        qa_inputs.append({"question": query, "context": context})
    return extractive_qa(qa_inputs)
|
|
|
|
|
@spaces.GPU(duration=15)
def respond_user_query(history):
    '''
    Retrieves passages for the user's query and appends one chatbot message per result.
    Each message has a header (formatted metadata, or the opinion id when there is
    none), an optional CourtListener link, and the paragraph with the span picked
    by the extractive QA model wrapped in "**".
    :param history: list of [user, bot] pairs; the query is read from history[0][0].
    :return: [history extended with one [None, message] pair per result,
              list of result dicts returned by run_retrieval].
    '''
    query = history[0][0]
    start_time_global = time.time()

    responses = run_retrieval(query)
    print("--- run retrieval: %s seconds ---" % (time.time() - start_time_global))

    contexts = [individual_response["text"] for individual_response in responses][:NUM_RESULTS]
    # Skip the QA model entirely when retrieval produced nothing.
    extracted_passages = run_extractive_qa(query, contexts) if contexts else []
    # The question-answering pipeline returns a bare dict (not a list) for a single
    # input; zip() over a dict would iterate its keys.
    if isinstance(extracted_passages, dict):
        extracted_passages = [extracted_passages]

    for individual_response, extracted_passage in zip(responses, extracted_passages):
        start, end = extracted_passage["start"], extracted_passage["end"]

        text = individual_response["text"]
        text = text[:start] + " **" + text[start:end] + "** " + text[end:]

        formatted_response = "##### "
        if individual_response["meta_data"]:
            formatted_response += individual_response["meta_data"]
        else:
            formatted_response += individual_response["opinion_idx"]

        # run_retrieval stores "" instead of a dict when an opinion has no metadata;
        # in that case there is no case name to look up, so no link is added.
        reranking_metadata = individual_response["metadata_reranking"]
        hyperlink = -1
        if isinstance(reranking_metadata, dict) and reranking_metadata.get("case_name"):
            casename = reranking_metadata["case_name"]
            citation = " ".join(reranking_metadata.get("citations") or [])
            court = reranking_metadata.get("court_short_name", "")
            hyperlink = run_courtlistener_api(casename, citation, court)

        if hyperlink != -1:
            formatted_response += "\n" + hyperlink + "\n"
        formatted_response += "\n" + text + "\n\n"
        history = history + [[None, formatted_response]]
    print("--- Extractive QA: %s seconds ---" % (time.time() - start_time_global))

    return [history, responses]
|
|
|
def switch_to_reviewing_framework():
    '''
    Swaps the query widgets for the feedback widgets once results are displayed.
    :return: four component updates, in order: hidden textbox, hidden dataset,
             visible interactive textbox, visible button.
    '''
    hidden_query_box = gr.Textbox(visible=False)
    hidden_examples = gr.Dataset(visible=False)
    feedback_box = gr.Textbox(visible=True, interactive=True)
    submit_button = gr.Button(visible=True)
    return hidden_query_box, hidden_examples, feedback_box, submit_button
|
|
|
def reset_interface():
    '''
    Puts the interface back into its initial state after a review is submitted.
    :return: six updates, in order: visible textbox, hidden button, hidden textbox
             with its value cleared, None, hidden JSON reset to [], visible dataset.
    '''
    query_box = gr.Textbox(visible=True)
    submit_button = gr.Button(visible=False)
    cleared_feedback_box = gr.Textbox(visible=False, value="")
    empty_chat = None
    cleared_responses = gr.JSON(visible=False, value=[])
    examples_panel = gr.Dataset(visible=True)
    return (
        query_box,
        submit_button,
        cleared_feedback_box,
        empty_chat,
        cleared_responses,
        examples_panel,
    )
|
|
|
|
|
def mark_like(response_json, like_data: gr.LikeData):
    '''
    Records a like/dislike click on the corresponding retrieved response.
    :param response_json: list of response dicts held by the hidden JSON component.
    :param like_data: event payload; `index[0]` is read as the chatbot row and
                      `liked` as the verdict.
    :return: the same list, with "is_msg_liked" set on the reviewed entry.
    '''
    # The chatbot row is shifted by one to get the position in response_json.
    response_position = like_data.index[0] - 1
    verdict = like_data.liked
    response_json[response_position]["is_msg_liked"] = verdict
    return response_json
|
|
|
""" |
|
def save_json(name: str, greetings: str) -> None: |
|
|
|
""" |
|
def register_review(history, additional_feedback, response_json):
    '''
    Appends the user's review of the current query as one JSON line to the dataset file.
    :param history: list of [user, bot] pairs; the query is read from history[0][0].
    :param additional_feedback: free-text feedback entered by the user.
    :param response_json: list of response dicts stored alongside the review.
    :return: None, writes to output file.
    '''
    review_record = {
        "user_query": history[0][0],
        "responses": response_json,
        "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "additional_feedback": additional_feedback,
    }

    # Hold the scheduler's lock for the duration of the append.
    with scheduler.lock, JSON_DATASET_PATH.open("a") as log_file:
        json.dump(review_record, log_file)
        log_file.write("\n")
|
|
|
|
|
|
|
|
|
|
|
def load_bm25():
    '''
    Loads the BM25 index and the matching tokenizer vocabulary from the Hugging Face Hub.
    :return: (retriever, bm25_tokenizer) - the BM25HF index and a TokenizerHF
             configured with an English stemmer, lower-casing and a custom splitter.
    '''
    hub_repo = "ai-law-society-lab/bm25s-federal-index"
    hub_token = os.getenv('hf_token')

    retriever = BM25HF.load_from_hub(hub_repo, token=hub_token)

    bm25_tokenizer = TokenizerHF(
        stemmer=Stemmer.Stemmer("english"),
        splitter=r"\b[\w()/:-]+\b",
        lower=True,
    )
    bm25_tokenizer.load_vocab_from_hub(hub_repo, token=hub_token)
    return retriever, bm25_tokenizer
|
|
|
def run_bm25(query):
    '''
    Retrieves the top 5 BM25 hits for a query using the module-level index and tokenizer.
    :param query: query text.
    :return: first row of the retrieval results, i.e. the hits for this single query.
    '''
    tokenized_query = bm25_tokenizer.tokenize(query)
    hits, _scores = retriever.retrieve(tokenized_query, k=5)
    return hits[0]
|
|
|
def load_faiss_index(embeddings):
    '''
    Builds a flat L2 FAISS index over the given embeddings.
    :param embeddings: 2-D array; its second dimension is used as the vector size.
    :return: the populated faiss.IndexFlatL2.
    '''
    _num_vectors, dim = embeddings.shape
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index
|
|
|
|
|
def run_dense_retrieval(query):
    '''
    Embeds the query with the module-level embedding model.
    :param query: query text.
    :return: L2-normalized query embedding as a float32 numpy array
             (presumably shape (1, dim), since a single query is encoded).
    :raises NotImplementedError: if `model_name` does not contain "NV"; no other
             encoder is implemented, and previously this surfaced as an unbound
             `query_embeddings` variable.
    '''
    if "NV" not in model_name:
        raise NotImplementedError(f"No query encoder is implemented for model {model_name!r}")

    query_prefix = "Instruct: Given a question, retrieve passages that answer the question\nQuery: "
    max_length = 32768
    print(query)
    with torch.no_grad():
        query_embeddings = model.encode([query], instruction=query_prefix, max_length=max_length)
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
    # The model is loaded with torch_dtype=float16, while FAISS searches float32
    # vectors; the cast is a no-op when the tensor is already float32.
    return query_embeddings.float().cpu().numpy()
|
|
|
|
|
|
|
def rerank_with_chatGPT(query, search_results):
    '''
    Asks gpt-4o to pick and order the best search results for a query.
    At most the first 50 results are sent to the model. The model's reply is
    expected to contain a python list of result ids.
    :param query: user query, embedded in the system prompt.
    :param search_results: list of result dicts with "index", "text" and
                           "metadata_reranking" keys.
    :return: result dicts in the order given by the model, or the first five
             entries of search_results when the prompt cannot be built, the API
             call fails, or the reply cannot be parsed into known ids.
    '''
    search_results_as_dict = {str(i["index"]): i for i in search_results}

    system_prompt = """You are given a list of search results for a query. Rerank the search results such that the paragraphs answering the query in the most comprehensive way are listed first. Additionaly, prioritize reranking in the following order:
1. prioritize metadata according to the query.
2. If the query doesn't ask for specific metadata, prioritize paragraphs from higher courts (Supreme Court first, Circuit courts next, district courts last)
3. Prioritize paragraphs which have higher citation counts.
4. Prioritize parapgrahs from more recent opinions.
Return a python list with the ids of the five highest ranking results, nothing else.
<query>""" + query + "</query>\n\n"

    # Reranking is best-effort: any failure (prompt building, API call, malformed
    # reply, unknown id) falls back to the retrieval order instead of failing the request.
    try:
        user_prompt = "\n".join(
            format_metadata_for_reranking(i["metadata_reranking"], i["text"], i["index"])
            for i in search_results[:50]
        )
        out = text_prompt_call("gpt-4o", system_prompt, user_prompt)
        print("OUT", out)

        # DOTALL lets the pattern match a list the model spread over several lines.
        ranked_ids = literal_eval(re.findall(r"\[.*?\]", out, flags=re.DOTALL)[0])
        # Keep the model's order but drop ids it repeated.
        unique_ids = list(dict.fromkeys(str(i) for i in ranked_ids))
        out_dict = [search_results_as_dict[i] for i in unique_ids]
        if not out_dict:
            raise ValueError("reranker returned an empty id list")
        print("SUCCESS")
    except Exception as e:
        print(e)
        out_dict = search_results[:5]
    print(out_dict)
    return out_dict
|
|
|
|
|
|
|
|
|
def run_retrieval(query):
    '''
    Runs hybrid retrieval (FAISS + BM25), keeps one paragraph per opinion, reranks
    the candidates with gpt-4o and logs the reranking input/output as a JSON line.
    :param query: user query; runs of whitespace are collapsed before use.
    :return: list of at most NUM_RESULTS result dicts with the keys "index",
             "NV_score", "text", "query", "opinion_idx", "meta_data" and
             "metadata_reranking" ("" when the opinion has no metadata entry).
    '''
    query = " ".join(query.split())
    print("query", query)

    indices_bm25 = run_bm25(query)
    query_embeddings = run_dense_retrieval(query)
    distances, neighbor_ids = faiss_index.search(query_embeddings, 35)

    # FAISS pads the labels with -1 when it finds fewer than k neighbours; a
    # negative index would silently select a row from the end of the dataset.
    candidates = [
        (int(i), score) for i, score in zip(neighbor_ids[0], distances[0]) if int(i) >= 0
    ]

    # BM25 hits not already found by the dense search are appended with the
    # placeholder score -100.
    seen_indices = {i for i, _ in candidates}
    for i in indices_bm25:
        i = int(i)
        if i not in seen_indices:
            seen_indices.add(i)
            candidates.append((i, -100))

    out_dict = []
    covered = set()
    for index, score in candidates:
        row = ds_paragraphs[index]  # fetch the dataset row once for both fields
        opinion_idx = str(row["idx"])

        # Only the first (best-placed) paragraph of each opinion is kept.
        if opinion_idx in covered:
            continue
        covered.add(opinion_idx)

        if opinion_idx in metadata:
            opinion_metadata = metadata[opinion_idx]
            meta_data = format_metadata_as_str(opinion_metadata)
        else:
            opinion_metadata = ""
            meta_data = ""

        out_dict.append({
            "index": index,
            "NV_score": score,
            "text": row["paragraph"],
            "query": query,
            "opinion_idx": opinion_idx,
            "meta_data": meta_data,
            "metadata_reranking": opinion_metadata,
        })
    print("out_dict_before_reranking")

    res = {
        "result_type": "chatgpt_reranking",
        "query": query,
        "input_reranking": [int(i["index"]) for i in out_dict],
        "scores_input_reranking": [float(i["NV_score"]) for i in out_dict],
    }
    out_dict = rerank_with_chatGPT(query, out_dict)[:NUM_RESULTS]

    res["output_reranking"] = [int(i["index"]) for i in out_dict]
    res["scores_output_reranking"] = [float(i["NV_score"]) for i in out_dict]
    print(res)

    with scheduler.lock:
        with JSON_DATASET_PATH.open("a") as f:
            json.dump(res, f)
            f.write("\n")

    print("RETURNING OUT DICT")
    return out_dict
|
|
|
# Upper bound on the results kept after reranking and passed to extractive QA.
NUM_RESULTS = 5
# Checked for the substring "NV" wherever an embedding model is loaded or used.
model_name = "nvidia/NV-Embed-v2"

# NOTE(review): `device` is not referenced by any code visible in this file - confirm it is still needed.
device = torch.device("mps")

# Extractive QA pipeline used by run_extractive_qa.
extractive_qa = pipeline(
    "question-answering",
    model="ai-law-society-lab/extractive-qa-model",
    tokenizer="FacebookAI/roberta-large",
    device_map="auto",
    token=os.getenv("hf_token"),
)

# Paragraph dataset; rows are read through their "paragraph" and "idx" fields.
ds_paragraphs = load_dataset(
    "ai-law-society-lab/federal-caselaw-paragraphs",
    token=os.getenv("hf_token"),
)["train"]
|
|
|
""" |
|
ds = load_dataset("ai-law-society-lab/federal-caselaw-embeddings-PCA-768", token=os.getenv('hf_token'))["train"] |
|
ds = ds.with_format("np") |
|
faiss_index = load_faiss_index(ds["embeddings"]) |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Download the prebuilt index file from the Hub dataset repo and load it with FAISS.
repo_id = "ai-law-society-lab/autofaiss-federal-index"
file_path = hf_hub_download(
    repo_id=repo_id,
    filename="knn.index",
    repo_type="dataset",
    token=os.getenv("hf_token"),
)
faiss_index = faiss.read_index(file_path)

# BM25 index and tokenizer used by run_bm25.
retriever, bm25_tokenizer = load_bm25()
|
|
|
|
|
""" |
|
with open('PCA_model.pkl', 'rb') as f: |
|
pca_model = pickle.load(f) |
|
""" |
|
|
|
# Opinion metadata, looked up by opinion id (as a string) in run_retrieval.
# The encoding is given explicitly so decoding does not depend on the platform's locale.
with open("Federal_caselaw_metadata.json", encoding="utf-8") as f:
    metadata = json.load(f)
|
|
|
|
|
def load_embeddings_model(model_name = "intfloat/e5-large-v2"):
    '''
    Loads the embedding model in float16 and switches it to eval mode.
    :param model_name: model identifier; only names containing "NV" are supported,
                       and for those the 'nvidia/NV-Embed-v2' checkpoint is loaded.
    :return: the loaded model.
    :raises NotImplementedError: for any other model name, including the default;
             previously this surfaced as an unbound `model` variable.
    '''
    if "NV" not in model_name:
        raise NotImplementedError(f"No loader is implemented for embedding model {model_name!r}")

    # NOTE: trust_remote_code=True executes the model code shipped in the Hub repository.
    model = AutoModel.from_pretrained(
        'nvidia/NV-Embed-v2',
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()
    return model
|
|
|
# The embedding model is only loaded for NV-Embed style model names.
if "NV" in model_name:
    model = load_embeddings_model(model_name=model_name)

# Example queries offered below the query textbox.
examples = ["Can officers always order a passenger out of a car?"]

# Custom CSS passed to gr.Blocks.
css = """
.svelte-i3tvor {visibility: hidden}
.row.svelte-hrj4a0.unequal-height {
align-items: stretch !important
}
"""
|
|
|
with gr.Blocks(css=css, theme=gr.themes.Monochrome(primary_hue="pink")) as demo:
    chatbot = gr.Chatbot(height="45vw", autoscroll=False)
    query_textbox = gr.Textbox()

    # From here on `examples` is the gr.Examples component, not the list of strings.
    examples = gr.Examples(examples, query_textbox)
    response_json = gr.JSON(visible=False, value=[])
    print(response_json)
    chatbot.like(fn=mark_like, inputs=response_json, outputs=response_json)

    feedback_textbox = gr.Textbox(label="Additional feedback?", visible=False)
    next_button = gr.Button(value="Submit Feedback", visible=False)

    # Query flow: show the query, fill in the results, then switch to the feedback widgets.
    query_shown = query_textbox.submit(
        fn=show_user_query,
        inputs=[query_textbox, chatbot],
        outputs=[query_textbox, chatbot],
        queue=False,
    )
    results_shown = query_shown.then(
        fn=respond_user_query,
        inputs=chatbot,
        outputs=[chatbot, response_json],
    )
    results_shown.then(
        fn=switch_to_reviewing_framework,
        inputs=None,
        outputs=[query_textbox, examples.dataset, feedback_textbox, next_button],
    )

    # Feedback flow: store the review, then reset the interface for the next query.
    review_saved = next_button.click(
        fn=register_review,
        inputs=[chatbot, feedback_textbox, response_json],
        outputs=None,
    )
    review_saved.then(
        fn=reset_interface,
        inputs=None,
        outputs=[query_textbox, next_button, feedback_textbox, chatbot, response_json, examples.dataset],
    )

demo.launch()