This commit is contained in:
2026-05-21 14:59:29 +02:00
parent 0d69f340ae
commit 6a25385409

View File

@@ -271,9 +271,9 @@ x = Dense(512, activation='relu')(x) # Added another Dense layer
x = Dense(256, activation='relu')(x) # Existing Dense layer x = Dense(256, activation='relu')(x) # Existing Dense layer
predictions = Dense(1, activation='sigmoid')(x) # Output layer for binary classification predictions = Dense(1, activation='sigmoid')(x) # Output layer for binary classification
with strategy.scope(): # Use all gpus # with strategy.scope(): # Use all gpus
model = Model(inputs=base_model.input, outputs=predictions) model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy']) model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])
"""## 4. Data Generators """## 4. Data Generators
@@ -340,9 +340,10 @@ model_checkpoint = ModelCheckpoint(
) )
# Train the model # Train the model
print("Training model...")
history = model.fit( history = model.fit(
train_generator, train_generator,
epochs=3, # You can adjust the number of epochs epochs=50, # You can adjust the number of epochs
validation_data=val_generator, validation_data=val_generator,
callbacks=[early_stopping, model_checkpoint], callbacks=[early_stopping, model_checkpoint],
class_weight=class_weights # Use class weights to handle imbalance class_weight=class_weights # Use class weights to handle imbalance