DroidDetect-Base
This is a text classification model based on answerdotai/ModernBERT-base, fine-tuned to distinguish between human-written, AI-generated, AI-refined, and adversarial AI-generated code. The model was trained on the DroidCollection dataset and is designed as a 4-class classifier to address the core task of AI code detection.
A key feature of this model is its training objective, which combines standard cross-entropy loss with a batch-hard triplet loss. This contrastive component encourages the model to learn more discriminative embeddings by pulling same-class representations together and pushing the different classes further apart in the embedding space.
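For intuition, here is a minimal sketch of the batch-hard soft-margin triplet term. This is an illustration of the usual formulation (hardest positive and hardest negative per anchor, with a softplus in place of a fixed margin), not the exact sentence-transformers implementation used in the training code below:

```python
import torch
import torch.nn.functional as F

def batch_hard_soft_margin_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Pairwise Euclidean distances between all embeddings in the batch.
    dist = torch.cdist(embeddings, embeddings, p=2)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B): True where labels match
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    # Hardest positive: the farthest same-class embedding (excluding the anchor itself).
    hardest_pos = dist.masked_fill(~same | eye, float("-inf")).max(dim=1).values
    # Hardest negative: the closest different-class embedding.
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values

    # Soft-margin variant: log(1 + exp(d_pos - d_neg)) == softplus(d_pos - d_neg).
    return F.softplus(hardest_pos - hardest_neg).mean()
```

In the model code below, this term is provided by sentence-transformers' BatchHardSoftMarginTripletLoss, applied to the projected embeddings.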
Model Details
- Base Model: answerdotai/ModernBERT-base
- Loss Function: Total Loss = CrossEntropyLoss + 0.1 * TripletLoss
- Dataset: Filtered training set of the DroidCollection.
Label Mapping
The model predicts one of 4 classes. The mapping from ID to label is as follows:
```json
{
  "0": "HUMAN_GENERATED",
  "1": "MACHINE_GENERATED",
  "2": "MACHINE_REFINED",
  "3": "MACHINE_GENERATED_ADVERSARIAL"
}
```
Model Code
The following code defines the model architecture and can be used to reproduce it:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import losses

TEXT_EMBEDDING_DIM = 768  # hidden size of ModernBERT-base
NUM_CLASSES = 4


class TLModel(nn.Module):
    def __init__(self, text_encoder, projection_dim=128, num_classes=NUM_CLASSES, class_weights=None):
        super().__init__()
        self.text_encoder = text_encoder
        self.num_classes = num_classes
        text_output_dim = TEXT_EMBEDDING_DIM
        # Batch-hard soft-margin triplet loss from sentence-transformers.
        self.additional_loss = losses.BatchHardSoftMarginTripletLoss(self.text_encoder)
        self.text_projection = nn.Linear(text_output_dim, projection_dim)
        self.classifier = nn.Linear(projection_dim, num_classes)
        self.class_weights = class_weights

    def forward(self, labels=None, input_ids=None, attention_mask=None):
        actual_labels = labels
        # Mean-pool the token embeddings into a single sequence embedding.
        sentence_embeddings = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        sentence_embeddings = sentence_embeddings.mean(dim=1)
        projected_text = F.relu(self.text_projection(sentence_embeddings))
        logits = self.classifier(projected_text)

        loss = None
        cross_entropy_loss = None
        contrastive_loss = None
        if actual_labels is not None:
            loss_fct_ce = nn.CrossEntropyLoss(
                weight=self.class_weights.to(logits.device) if self.class_weights is not None else None
            )
            cross_entropy_loss = loss_fct_ce(logits.view(-1, self.num_classes), actual_labels.view(-1))
            contrastive_loss = self.additional_loss.batch_hard_triplet_loss(
                embeddings=projected_text, labels=actual_labels
            )
            # Total loss = CE + 0.1 * triplet (see "Loss Function" above).
            lambda_contrast = 0.1
            loss = cross_entropy_loss + lambda_contrast * contrastive_loss

        output = {"logits": logits, "fused_embedding": projected_text}
        if loss is not None:
            output["loss"] = loss
        if cross_entropy_loss is not None:
            output["cross_entropy_loss"] = cross_entropy_loss
        if contrastive_loss is not None:
            output["contrastive_loss"] = contrastive_loss
        return output
```
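A hypothetical inference sketch follows. It assumes the tokenizer and encoder are initialized from the base model via transformers' Auto classes and that the fine-tuned TLModel weights are available as a local state dict; the checkpoint filename is illustrative, not an official file:

```python
import torch
from transformers import AutoModel, AutoTokenizer

ID2LABEL = {
    0: "HUMAN_GENERATED",
    1: "MACHINE_GENERATED",
    2: "MACHINE_REFINED",
    3: "MACHINE_GENERATED_ADVERSARIAL",
}

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
model = TLModel(text_encoder=encoder)
# model.load_state_dict(torch.load("droiddetect_base.pt"))  # illustrative path
model.eval()

inputs = tokenizer("def add(a, b):\n    return a + b", return_tensors="pt", truncation=True)
with torch.no_grad():
    output = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
predicted_id = output["logits"].argmax(dim=-1).item()
print(ID2LABEL[predicted_id])
```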