Spaces: Runtime error
# Fine-tune a new classification head on top of a frozen, pre-loaded Keras
# base model.
#
# NOTE(review): this snippet does not define `base_model`, `num_classes`,
# `train_dataset` or `val_dataset`; they must exist before these lines run.
from tensorflow.keras.layers import (
    Dense,
    GlobalAveragePooling1D,
    GlobalAveragePooling2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def _global_average_pool(features):
    """Average *features* over its non-batch, non-channel axes.

    Rank 4 (batch, height, width, channels) uses GlobalAveragePooling2D, as the
    original script always did. Rank 3 (batch, steps, channels) uses
    GlobalAveragePooling1D. Rank 2 is already pooled and is returned unchanged.

    Raises:
        ValueError: if *features* has any other rank.
    """
    rank = len(features.shape)
    if rank == 4:
        return GlobalAveragePooling2D()(features)
    if rank == 3:
        # GlobalAveragePooling2D rejects 3-D input, which is what sequence
        # backbones (e.g. transformer language models) typically emit.
        return GlobalAveragePooling1D()(features)
    if rank == 2:
        return features
    raise ValueError(
        f"Unsupported base_model output rank {rank}; expected 2, 3 or 4."
    )


# NOTE(review): the original comment calls base_model a pre-loaded
# "Mixtral 8x7b" model. That is a language model, whose outputs are usually
# 3-D rather than image feature maps -- confirm base_model is a Keras model
# exposing `.input` / `.output`.
base_model.trainable = False  # Freeze the backbone; only the new head trains.

# Add custom layers
x = base_model.output
x = _global_average_pool(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

# This is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

# `lr` is a deprecated alias: newer Keras releases reject it outright and
# TF 2.11+ optimizers only warn and ignore it. `learning_rate` is the
# supported argument.
# categorical_crossentropy expects one-hot labels; use
# sparse_categorical_crossentropy if the datasets yield integer class ids.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model
model.fit(train_dataset, validation_data=val_dataset, epochs=10)