# p3 / app.py
# Author: Hem345 — commit 2d064e0 ("Update app.py")
import streamlit as st
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
def get_user_data_train(n_points=5):
    """Collect labelled 2-D training points from Streamlit inputs.

    For each point, three widgets are rendered in order: an x number input,
    a y number input and a free-text label input.

    Args:
        n_points: How many training points to ask for (default 5).

    Returns:
        A ``(points, labels)`` tuple: ``points`` is a float array of shape
        ``(n_points, 2)`` and ``labels`` is an array of the entered strings.
        Labels the user has not filled in are the empty string returned by
        ``st.text_input``.
    """
    data_points = []
    labels = []
    for i in range(n_points):
        x = st.number_input(f"Enter x-coordinate for data point {i + 1}:")
        y = st.number_input(f"Enter y-coordinate for data point {i + 1}:")
        label = st.text_input(f"Enter label for data point {i + 1}:")
        data_points.append([x, y])
        labels.append(label)
    # reshape keeps the result 2-D even when n_points == 0
    return np.array(data_points, dtype=float).reshape(-1, 2), np.array(labels)
def get_user_data_test(n_points=1):
    """Collect unlabelled 2-D test points from Streamlit inputs.

    Widgets are numbered from 1. The original ``range(1, 2, 1)`` loop ran once
    with ``i == 1`` and therefore labelled its single point "test data point 2".

    Args:
        n_points: How many test points to ask for (default 1).

    Returns:
        A float array of shape ``(n_points, 2)`` with the entered coordinates.
    """
    data_points = []
    for i in range(n_points):
        x = st.number_input(f"Enter x-coordinate for test data point {i + 1}:")
        y = st.number_input(f"Enter y-coordinate for test data point {i + 1}:")
        data_points.append([x, y])
    # reshape keeps the result 2-D even when n_points == 0
    return np.array(data_points, dtype=float).reshape(-1, 2)
def knn_classification(X_train, y_train, X_test, k_value):
    """Train a k-nearest-neighbour classifier and label the test points.

    Args:
        X_train: Training feature matrix.
        y_train: Class label for each training row.
        X_test: Feature matrix to classify.
        k_value: Number of neighbours consulted per prediction.

    Returns:
        The predicted label for each row of ``X_test``.
    """
    model = KNeighborsClassifier(n_neighbors=k_value)
    # fit() returns the estimator itself, so prediction can be chained on it
    return model.fit(X_train, y_train).predict(X_test)
def plot_training_and_test_data(X_train, y_train, X_test, predictions):
    """Scatter-plot training points by class and test points by predicted class.

    Each class gets one colour from matplotlib's "CN" colour cycle. Training
    points use the default marker, and test points use an ``x`` marker in the
    colour of their predicted class. The previous version passed the string
    labels as ``c=``, which matplotlib rejects unless they happen to be colour
    names. The figure is rendered with ``st.pyplot(fig)`` and then closed, so
    figures do not pile up across Streamlit reruns.

    Args:
        X_train: Training coordinates; column 0 is x, column 1 is y.
        y_train: Class label for each row of ``X_train``.
        X_test: Test coordinates; column 0 is x, column 1 is y.
        predictions: Predicted class label for each row of ``X_test``.
    """
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    predictions = np.asarray(predictions)

    fig, ax = plt.subplots()
    for idx, label in enumerate(np.unique(y_train)):
        color = f"C{idx}"  # same colour for this class's training and test points
        train_mask = y_train == label
        ax.scatter(X_train[train_mask, 0], X_train[train_mask, 1],
                   color=color, label=f"Training ({label})")
        test_mask = predictions == label
        if test_mask.any():
            ax.scatter(X_test[test_mask, 0], X_test[test_mask, 1],
                       color=color, marker="x", label=f"Test (predicted {label})")
    ax.set_xlabel("X-coordinate")
    ax.set_ylabel("Y-coordinate")
    ax.set_title("Training and Test Data with Predicted Labels")
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit reruns the script on every interaction
def main():
    """Render the k-NN demo: collect points, classify the test points, plot and report."""
    st.title("K-Nearest Neighbor Classification App")
    # Get user-defined training data
    X_train, y_train = get_user_data_train()
    # KNeighborsClassifier rejects n_neighbors larger than the number of
    # training samples, so the slider is bounded by len(X_train) (it used to
    # allow up to 10 with only 5 training points).
    n_train = len(X_train)
    k_value = st.slider(
        "Choose the value of k for k-nearest neighbors:",
        min_value=1,
        max_value=n_train,
        value=min(3, n_train),
    )
    # Get user-defined test data
    X_test = get_user_data_test()
    # st.text_input returns "" until the user types something; classifying
    # with blank labels would only produce meaningless "" predictions.
    if any(not label.strip() for label in y_train):
        st.info("Enter a label for every training data point to see predictions.")
        return
    # Perform k-nearest neighbor classification
    predictions = knn_classification(X_train, y_train, X_test, k_value)
    # Plot training and test data
    plot_training_and_test_data(X_train, y_train, X_test, predictions)
    # Display results
    st.subheader("Results:")
    st.write("User-defined Data Points for Testing:")
    st.write(X_test)
    st.write(f"\nK-Nearest Neighbor Classification (k={k_value}):")
    st.write("Predicted Labels:")
    st.write(predictions)


if __name__ == "__main__":
    main()