import gradio as gr
import os
import json
import uuid
import torch
import datetime
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoConfig
from huggingface_hub import HfApi, create_repo, hf_hub_download
from torchcrf import CRF
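# ---------------------------------------------------------------------------
# Gradio demo: machine-generated text detection with a token-level
# mDeBERTa-v3 + CRF tagger. Each word is colored by its probability of being
# machine-generated, and every submission is logged to a HF dataset repo.
# ---------------------------------------------------------------------------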
# Constants
HF_DATASET_REPO = "M2ai/mgtd-logs"
HF_TOKEN = os.getenv("Mgtd")
DATASET_CREATED = False
# Model identifiers
code = "ENG"
pntr = 2
model_name_or_path = "microsoft/mdeberta-v3-base"
hf_token = HF_TOKEN  # same "Mgtd" env token, reused for the checkpoint download
# Download model checkpoint
file_path = hf_hub_download(
    repo_id="1024m/MGTD-Long-New",
    filename=f"{code}/mdeberta-epoch-{pntr}.pt",
    token=hf_token,
    local_dir="./checkpoints",
)
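# Create the logging dataset repo once per process. exist_ok=True makes this
# idempotent, so restarts of the app don't fail if the repo already exists.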
def setup_hf_dataset():
global DATASET_CREATED
if not DATASET_CREATED and HF_TOKEN:
try:
create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
DATASET_CREATED = True
print(f"Dataset {HF_DATASET_REPO} is ready.")
except Exception as e:
print(f"Error setting up dataset: {e}")
class AutoModelCRF(nn.Module):
def __init__(self, model_name_or_path, dropout=0.075):
super().__init__()
self.config = AutoConfig.from_pretrained(model_name_or_path)
self.num_labels = 2
self.encoder = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, config=self.config)
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
self.crf = CRF(self.num_labels, batch_first=True)
def forward(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
seq_output = self.dropout(outputs[0])
emissions = self.linear(seq_output)
tags = self.crf.decode(emissions, attention_mask.byte())
return tags, emissions
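# Load the tokenizer and restore the fine-tuned weights. The checkpoint may
# either wrap the weights under "model_state_dict" or be a bare state dict;
# strict=False tolerates keys that don't match the module exactly.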
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelCRF(model_name_or_path)
checkpoint = torch.load(file_path, map_location="cpu")
model.load_state_dict(checkpoint.get("model_state_dict", checkpoint), strict=False)
model = model.to(device)
model.eval()
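# Map a word's machine-generated probability to a traffic-light color bucket:
# <0.25 green, <0.5 yellow, <0.75 orange, else red.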
def get_color(prob):
if prob < 0.25:
return "green"
elif prob < 0.5:
return "yellow"
elif prob < 0.75:
return "orange"
else:
return "red"
def get_word_probabilities(text):
text = " ".join(text.split(" ")[:2048])
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
with torch.no_grad():
tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
word_probs = []
word_colors = []
current_word = ""
current_probs = []
for token, prob in zip(tokens, probs):
if token in ["", ""]:
continue
if token.startswith("▁"):
if current_word and current_probs:
current_prob = sum(current_probs) / len(current_probs)
word_probs.append(current_prob)
color = get_color(current_prob)
word_colors.append(color)
current_word = token[1:] if token != "▁" else ""
current_probs = [prob]
else:
current_word += token
current_probs.append(prob)
if current_word and current_probs:
current_prob = sum(current_probs) / len(current_probs)
word_probs.append(current_prob)
color = get_color(current_prob)
word_colors.append(color)
    ####### FOR STABLE OUTPUTS
    # The first and last words sit next to the special tokens and tend to get
    # noisy scores, so recolor them from their nearest neighbors. Guard against
    # short inputs, where word_probs[2] / word_probs[-3] would raise IndexError.
    if len(word_probs) >= 3:
        first_avg = (word_probs[1] + word_probs[2]) / 2
        word_colors[0] = get_color(first_avg)
        last_avg = (word_probs[-2] + word_probs[-3]) / 2
        word_colors[-1] = get_color(last_avg)
    #########
word_probs = [float(p) for p in word_probs]
return word_probs, word_colors
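# Run inference, write the result to a local JSON log, and mirror the log file
# into the HF dataset repo (best-effort: upload errors are printed, not raised).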
def infer_and_log(text_input):
word_probs, word_colors = get_word_probabilities(text_input)
timestamp = datetime.datetime.now().isoformat()
submission_id = str(uuid.uuid4())
log_data = {"id": submission_id,"timestamp": timestamp,"input": text_input,"output_probs": word_probs}
os.makedirs("logs", exist_ok=True)
log_file = f"logs/{timestamp.replace(':', '_')}.json"
with open(log_file, "w") as f:
json.dump(log_data, f, indent=2)
if HF_TOKEN and DATASET_CREATED:
try:
            HfApi().upload_file(
                path_or_fileobj=log_file,
                path_in_repo=f"logs/{os.path.basename(log_file)}",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
print(f"Uploaded log {submission_id}")
except Exception as e:
print(f"Error uploading log: {e}")
    tokens = text_input.split()
    # Wrap each word in a colored <span>; zip() silently drops trailing words
    # if tokenizer truncation left fewer color entries than words.
    formatted_output = " ".join(
        f'<span style="color:{color}">{token}</span>'
        for token, color in zip(tokens, word_colors)
    )
return formatted_output, word_probs
def clear_fields():
return "", "", {}
setup_hf_dataset()
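# UI: text box in, color-coded HTML out; the raw per-word probabilities are
# kept in a hidden JSON component alongside the rendered output.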
with gr.Blocks() as app:
gr.Markdown("Machine Generated Text Detector")
with gr.Row():
input_box = gr.Textbox(label="Input Text", lines=10)
output_html = gr.HTML(label="Color-Coded Output")
output_json = gr.JSON(label="Word Probabilities",visible=False)
with gr.Row():
submit_btn = gr.Button("Submit")
clear_btn = gr.Button("Clear")
submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=[output_html, output_json])
clear_btn.click(fn=clear_fields, outputs=[input_box, output_html, output_json])
if __name__ == "__main__":
app.launch()