|
import datetime
import html
import json
import os
import uuid

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import HfApi, create_repo, hf_hub_download
from torchcrf import CRF
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
|
|
# Hub dataset that receives one JSON log file per submission.
HF_DATASET_REPO = "M2ai/mgtd-logs"
# Access token from the "Mgtd" environment variable; None when unset.
HF_TOKEN = os.getenv("Mgtd")
# Set to True by setup_hf_dataset() once the dataset repo exists.
DATASET_CREATED = False

# Checkpoint selection inside the model repo: `code` is the sub-folder
# (presumably a language code) and `pntr` the epoch number in the file name.
code = "ENG"
pntr = 2
model_name_or_path = "microsoft/mdeberta-v3-base"
# Same "Mgtd" variable as HF_TOKEN (os.getenv is environ.get), kept under its
# own name because the checkpoint download below uses it.
hf_token = HF_TOKEN

# Runs at import time: fetches the fine-tuned checkpoint into ./checkpoints.
file_path = hf_hub_download(
    repo_id="1024m/MGTD-Long-New",
    filename=f"{code}/mdeberta-epoch-{pntr}.pt",
    token=hf_token,
    local_dir="./checkpoints",
)
|
|
|
def setup_hf_dataset():
    """Ensure the Hub logging dataset exists, at most once per process.

    Does nothing when the dataset was already set up or when no token is
    configured. On success the module-level ``DATASET_CREATED`` flag is set.
    Errors are printed instead of raised, so the app still starts without
    remote logging.
    """
    global DATASET_CREATED
    # Guard clause: the negation of "not created yet and token available".
    if DATASET_CREATED or not HF_TOKEN:
        return
    try:
        create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
        DATASET_CREATED = True
        print(f"Dataset {HF_DATASET_REPO} is ready.")
    except Exception as err:
        print(f"Error setting up dataset: {err}")
|
|
|
class AutoModelCRF(nn.Module):
    """Pretrained transformer encoder + linear emission head + CRF over 2 labels.

    Args:
        model_name_or_path: Hugging Face model id or local path of the encoder.
        dropout: Dropout probability applied to the encoder output before the
            linear head.
    """

    def __init__(self, model_name_or_path, dropout=0.075):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.num_labels = 2
        # The attribute names encoder/dropout/linear/crf are the state_dict key
        # prefixes used when a checkpoint is loaded; do not rename them.
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            config=self.config,
        )
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.config.hidden_size
        self.linear = nn.Linear(hidden_size, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask):
        """Encode a batch and decode label sequences.

        Returns:
            ``(tags, emissions)``: ``emissions`` are the per-token scores from
            the linear head; ``tags`` is the CRF decoding of those scores,
            restricted to positions where ``attention_mask`` is set.
        """
        # First element of the encoder output: the per-token hidden states.
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        emissions = self.linear(self.dropout(hidden_states))
        crf_mask = attention_mask.byte()
        tags = self.crf.decode(emissions, crf_mask)
        return tags, emissions
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelCRF(model_name_or_path)

# NOTE(review): torch.load unpickles the downloaded file, so only load
# checkpoints from a repository you trust. On recent PyTorch releases the
# default weights_only=True may reject checkpoints that hold non-tensor
# objects; confirm against how this checkpoint was saved.
checkpoint = torch.load(file_path, map_location="cpu")
# Accept either a {"model_state_dict": ...} wrapper or a bare state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint)
# strict=False tolerates key mismatches. Report them anyway, so a checkpoint
# that does not fit this architecture (leaving e.g. the linear/CRF head at its
# random initialisation) does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Checkpoint key mismatch: "
        f"{len(load_result.missing_keys)} missing (e.g. {load_result.missing_keys[:5]}), "
        f"{len(load_result.unexpected_keys)} unexpected (e.g. {load_result.unexpected_keys[:5]})"
    )
model = model.to(device)
# Inference only: disables dropout.
model.eval()
|
|
|
def get_color(prob):
    """Map a probability to a display colour name.

    Bands (upper bound exclusive): below 0.25 is "green", below 0.5 is
    "yellow", below 0.75 is "orange". Anything else is "red", including values
    that compare False against every bound (e.g. NaN).
    """
    for upper_bound, colour in ((0.25, "green"), (0.5, "yellow"), (0.75, "orange")):
        if prob < upper_bound:
            return colour
    return "red"
|
|
|
# SentencePiece (the mDeBERTa tokenizer's scheme) prefixes word-initial pieces
# with U+2581 "▁". The marker is written as an escape so it survives
# re-encoding of this source file; a mangled literal never matches any token.
_WORD_START = "\u2581"


def _merge_subword_probs(tokens, probs, skip_tokens):
    """Average token-level probabilities into one value per word.

    A token starting with ``_WORD_START`` opens a new word; any other token
    continues the current word. Tokens in *skip_tokens* are ignored. A word is
    only emitted once it has some text, so a bare marker token followed
    directly by another word-start contributes nothing.
    """
    word_probs = []
    current_word = ""
    current_probs = []
    for token, prob in zip(tokens, probs):
        if token in skip_tokens:
            continue
        if token.startswith(_WORD_START):
            if current_word and current_probs:
                word_probs.append(sum(current_probs) / len(current_probs))
            current_word = token[len(_WORD_START):]
            current_probs = [prob]
        else:
            current_word += token
            current_probs.append(prob)
    if current_word and current_probs:
        word_probs.append(sum(current_probs) / len(current_probs))
    return word_probs


def get_word_probabilities(text):
    """Score *text* and return per-word probabilities with display colours.

    Args:
        text: Raw user input. Only the first 2048 space-separated chunks are
            scored.

    Returns:
        ``(word_probs, word_colors)``. ``word_probs`` is a list of Python
        floats: for each word, the mean class-1 softmax probability over its
        sub-word tokens. ``word_colors`` holds the matching ``get_color``
        names. With three or more words, the first and last colours are
        replaced by the colour of the mean of their two nearest neighbours.
        Both lists are empty when the text yields no words.
    """
    text = " ".join(text.split(" ")[:2048])
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    with torch.no_grad():
        _tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Probability of label 1 (presumably "machine-generated") for every token
    # of the single sequence in the batch.
    probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()

    # Skip whatever special tokens this tokenizer adds ([CLS]/[SEP]/... for
    # DeBERTa-v3) rather than hard-coding RoBERTa's "<s>"/"</s>".
    special_tokens = set(tokenizer.all_special_tokens)
    word_probs = _merge_subword_probs(tokens, probs, special_tokens)
    word_colors = [get_color(p) for p in word_probs]

    # Boundary smoothing needs at least three words; shorter (or empty) inputs
    # keep their raw colours instead of failing on out-of-range indexing.
    if len(word_probs) >= 3:
        word_colors[0] = get_color((word_probs[1] + word_probs[2]) / 2)
        word_colors[-1] = get_color((word_probs[-2] + word_probs[-3]) / 2)

    # Plain floats so the result is JSON-serialisable for logging and gr.JSON.
    word_probs = [float(p) for p in word_probs]
    return word_probs, word_colors
|
|
|
def infer_and_log(text_input):
    """Score *text_input*, log the request, and build the colour-coded HTML.

    Args:
        text_input: Raw text from the input box.

    Returns:
        ``(formatted_output, word_probs)``: an HTML string with one coloured
        ``<span>`` per whitespace-separated word (HTML-escaped), and the
        per-word probabilities from ``get_word_probabilities``.

    Side effects:
        Writes ``logs/<timestamp>.json`` locally. When ``HF_TOKEN`` is set and
        the dataset was created, the file is also uploaded to
        ``HF_DATASET_REPO``. Upload errors are printed, not raised.
    """
    word_probs, word_colors = get_word_probabilities(text_input)
    timestamp = datetime.datetime.now().isoformat()
    submission_id = str(uuid.uuid4())
    log_data = {"id": submission_id, "timestamp": timestamp, "input": text_input, "output_probs": word_probs}

    os.makedirs("logs", exist_ok=True)
    # ':' is replaced to keep the file name portable (it is invalid on Windows).
    log_file = f"logs/{timestamp.replace(':', '_')}.json"
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)

    if HF_TOKEN and DATASET_CREATED:
        try:
            HfApi().upload_file(
                path_or_fileobj=log_file,
                path_in_repo=f"logs/{os.path.basename(log_file)}",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            print(f"Uploaded log {submission_id}")
        except Exception as e:
            # Best effort: a failed upload must not break the user's response.
            print(f"Error uploading log: {e}")

    # The words are user-controlled and rendered by gr.HTML, so escape them.
    # Otherwise input such as "<b>" or "a < b" is interpreted as markup.
    tokens = text_input.split()
    formatted_output = " ".join(
        f'<span style= "color:{color}">{html.escape(token)}</span>'
        for token, color in zip(tokens, word_colors)
    )
    return formatted_output, word_probs
|
|
|
def clear_fields():
    """Return empty values for the input box, the HTML output and the JSON output."""
    empty_text, empty_html, empty_json = "", "", {}
    return empty_text, empty_html, empty_json
|
# Prepare the Hub logging dataset before the UI is built. This is a no-op
# when no token is configured.
setup_hf_dataset()

# Gradio UI: the input textbox and the colour-coded HTML output sit side by
# side in one row, with the Submit/Clear buttons in a second row.
with gr.Blocks() as app:
    gr.Markdown("Machine Generated Text Detector")
    with gr.Row():
        input_box = gr.Textbox(label="Input Text", lines=10)
        output_html = gr.HTML(label="Color-Coded Output")
        # Hidden: receives the per-word probabilities from infer_and_log but is not shown.
        output_json = gr.JSON(label="Word Probabilities",visible=False)
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    # infer_and_log returns (html, word_probs), matching this outputs order.
    submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=[output_html, output_json])
    # clear_fields returns one empty value per listed component, in this order.
    clear_btn.click(fn=clear_fields, outputs=[input_box, output_html, output_json])

if __name__ == "__main__":
    app.launch()
|
|