Commit
·
9a09d32
1
Parent(s):
e489d41
Update Sniffer_AI.py
Browse files- Sniffer_AI.py +46 -19
Sniffer_AI.py
CHANGED
@@ -13,6 +13,19 @@ dt_model = joblib.load('decision_tree_model.pkl')
|
|
13 |
bagging_model = joblib.load('model_bagging.pkl')
|
14 |
ada_model = joblib.load('model_adaboost.pkl')
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
class_labels = {
|
17 |
0: "normal",
|
18 |
1: "backdoor",
|
@@ -26,39 +39,53 @@ class_labels = {
|
|
26 |
9: "mitm"
|
27 |
}
|
28 |
|
29 |
-
def detect_intrusion(
|
30 |
-
#
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
# Choose the model based on user selection
|
34 |
if model_choice == "Random Forest":
|
35 |
model = rf_model
|
36 |
elif model_choice == "Decision Tree":
|
37 |
-
model =
|
38 |
elif model_choice == "Bagging Classifier":
|
39 |
-
model =
|
40 |
elif model_choice == "AdaBoost Classifier":
|
41 |
-
model =
|
42 |
else:
|
43 |
return "Invalid model choice!"
|
44 |
-
|
45 |
# Predict the class (multi-class classification)
|
46 |
-
prediction = model.predict(
|
47 |
-
predicted_class = prediction[0] # Get the predicted class (an integer between 0-
|
48 |
-
|
49 |
-
#
|
50 |
if predicted_class == 0:
|
51 |
return "No Intrusion Detected"
|
52 |
else:
|
53 |
return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
|
54 |
|
55 |
-
# Create
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
# Launch the interface locally for testing
|
64 |
iface.launch()
|
|
|
13 |
bagging_model = joblib.load('model_bagging.pkl')
|
14 |
ada_model = joblib.load('model_adaboost.pkl')
|
15 |
|
16 |
+
# Define the feature names
|
17 |
+
feature_names = [
|
18 |
+
"src_ip", "src_port", "dst_ip", "dst_port", "proto", "service", "duration",
|
19 |
+
"src_bytes", "dst_bytes", "conn_state", "missed_bytes", "src_pkts",
|
20 |
+
"src_ip_bytes", "dst_pkts", "dst_ip_bytes", "dns_query", "dns_qclass",
|
21 |
+
"dns_qtype", "dns_rcode", "dns_AA", "dns_RD", "dns_RA", "dns_rejected",
|
22 |
+
"ssl_version", "ssl_cipher", "ssl_resumed", "ssl_established", "ssl_subject",
|
23 |
+
"ssl_issuer", "http_trans_depth", "http_method", "http_uri", "http_version",
|
24 |
+
"http_request_body_len", "http_response_body_len", "http_status_code",
|
25 |
+
"http_user_agent", "http_orig_mime_types", "http_resp_mime_types",
|
26 |
+
"weird_name", "weird_addl", "weird_notice", "label"
|
27 |
+
]
|
28 |
+
|
29 |
class_labels = {
|
30 |
0: "normal",
|
31 |
1: "backdoor",
|
|
|
39 |
9: "mitm"
|
40 |
}
|
41 |
|
42 |
+
def detect_intrusion(feature_values, model_choice="Random Forest"):
|
43 |
+
# Ensure the length of feature_values matches feature_names
|
44 |
+
if len(feature_values) != len(feature_names):
|
45 |
+
return "Please fill in all the required feature values."
|
46 |
+
|
47 |
+
# Convert the input values to floats and match them with feature names
|
48 |
+
try:
|
49 |
+
feature_values = [float(value) for value in feature_values]
|
50 |
+
except ValueError:
|
51 |
+
return "Please enter valid numerical values for all fields."
|
52 |
+
|
53 |
# Choose the model based on user selection
|
54 |
if model_choice == "Random Forest":
|
55 |
model = rf_model
|
56 |
elif model_choice == "Decision Tree":
|
57 |
+
model = dt_model
|
58 |
elif model_choice == "Bagging Classifier":
|
59 |
+
model = bagging_model
|
60 |
elif model_choice == "AdaBoost Classifier":
|
61 |
+
model = ada_model
|
62 |
else:
|
63 |
return "Invalid model choice!"
|
64 |
+
|
65 |
# Predict the class (multi-class classification)
|
66 |
+
prediction = model.predict([feature_values])
|
67 |
+
predicted_class = prediction[0] # Get the predicted class (an integer between 0-9)
|
68 |
+
|
69 |
+
# Notify the user of the detected attack or normal traffic
|
70 |
if predicted_class == 0:
|
71 |
return "No Intrusion Detected"
|
72 |
else:
|
73 |
return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
|
74 |
|
75 |
+
# Create Gradio input fields for each feature
|
76 |
+
inputs = [gr.Textbox(label=feature_name) for feature_name in feature_names[:-1]] # Exclude "label" field from inputs
|
77 |
+
|
78 |
+
# Add model choice dropdown
|
79 |
+
inputs.append(gr.Dropdown(choices=["Random Forest", "Decision Tree", "Bagging Classifier", "AdaBoost Classifier"], label="Select Model"))
|
80 |
+
|
81 |
+
# Create the Gradio interface
|
82 |
+
iface = gr.Interface(
|
83 |
+
fn=detect_intrusion,
|
84 |
+
inputs=inputs,
|
85 |
+
outputs="text",
|
86 |
+
title="Intrusion Detection System",
|
87 |
+
description="Fill in the blank fields for the network traffic features, and choose the model to detect intrusions."
|
88 |
+
)
|
89 |
|
90 |
# Launch the interface locally for testing
|
91 |
iface.launch()
|