ddroiddesignmoe / training
dnnsdunca's picture
training dataset
9cb6f8b verified
raw
history blame contribute delete
730 Bytes
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
# Transfer-learning script: attach a small classification head to a frozen,
# pre-loaded backbone and train only the newly added layers.
#
# Names this script reads but does not define (they must exist before this
# point): base_model, num_classes, train_dataset, val_dataset.

# Assuming base_model is your pre-loaded Mixtral 8x7b model
# NOTE(review): GlobalAveragePooling2D below operates on 4D feature maps
# (image-style backbones); confirm base_model.output really has that shape,
# since a language model such as Mixtral would normally emit 3D sequences.
base_model.trainable = False  # Freeze the backbone so only the new head is trained

# Add custom layers on top of the backbone's output
x = base_model.output
x = GlobalAveragePooling2D()(x)  # Collapse the two spatial axes into one vector per sample
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)  # One probability per class

# This is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated in
# TF 2.x and rejected by current Keras optimizers.
# NOTE(review): 'categorical_crossentropy' presumes one-hot encoded labels in
# train_dataset / val_dataset; verify against the dataset pipeline.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model
model.fit(train_dataset, validation_data=val_dataset, epochs=10)