Balance the training classes by oversampling the minority class and train without early stopping to improve accuracy
This commit is contained in:
@@ -203,6 +203,33 @@ val_df['target'] = val_df['target'].astype(str)
|
|||||||
print("Train target dtype after conversion:", train_df['target'].dtype)
|
print("Train target dtype after conversion:", train_df['target'].dtype)
|
||||||
print("Validation target dtype after conversion:", val_df['target'].dtype)
|
print("Validation target dtype after conversion:", val_df['target'].dtype)
|
||||||
|
|
||||||
|
"""### Balancing classes by oversampling the minority class"""
|
||||||
|
|
||||||
|
# Balance the training set by random oversampling: every class that is smaller
# than the largest class is resampled with replacement up to that size.
class_counts = train_df['target'].value_counts()
target_count = class_counts.max()

# Resample class by class instead of selecting rows via idxmax()/idxmin().
# When the class counts are tied, idxmax() and idxmin() return the same label,
# so selecting "majority" and "minority" rows that way would pick the same
# class twice and silently drop the other one.
balanced_parts = []
for label, count in class_counts.items():
    class_rows = train_df[train_df['target'] == label]
    if count < target_count:
        # Fixed random_state keeps the oversampled rows reproducible.
        class_rows = class_rows.sample(
            n=target_count, replace=True, random_state=42
        )
    balanced_parts.append(class_rows)

# value_counts() is sorted by descending count, so the largest class comes
# first, followed by the oversampled smaller class(es).
train_df_balanced = pd.concat(balanced_parts)

# Shuffle so that batches are not grouped by class, and rebuild the index.
train_df_balanced = train_df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

print("Original train_df class distribution:")
print(train_df['target'].value_counts())
print("\nBalanced train_df_balanced class distribution:")
print(train_df_balanced['target'].value_counts())
|
||||||
|
|
||||||
plt.figure(figsize=(6,3))
|
plt.figure(figsize=(6,3))
|
||||||
|
|
||||||
sns.countplot(data=df, x='target')
|
sns.countplot(data=df, x='target')
|
||||||
@@ -320,17 +347,53 @@ val_generator = val_datagen.flow_from_dataframe(
|
|||||||
seed=42
|
seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""Now, the `train_generator` will use the `train_df_balanced` DataFrame, which has an equal number of samples for both classes. This will help the model learn more effectively from the minority class during training."""
|
||||||
|
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
# Input resolution fed to the network (square images).
IMG_WIDTH = IMG_HEIGHT = 224

# Random augmentations applied to training images only.
_train_augmentation = {
    'rotation_range': 20,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Training generator: model-specific preprocessing plus augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    **_train_augmentation,
)

# Validation generator: same preprocessing, no augmentation.
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
|
||||||
|
|
||||||
|
# Arguments shared by the training and validation iterators.
# `target_size` is documented by Keras as (height, width); the original passed
# (IMG_WIDTH, IMG_HEIGHT), which only works while both values are equal.
_flow_kwargs = dict(
    x_col='path',
    y_col='target',
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=32,
    class_mode='binary',
    seed=42,
)

# Training batches are drawn from the balanced (oversampled) DataFrame.
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df_balanced,
    **_flow_kwargs,
)

# Validation batches come from the untouched validation split.
val_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    **_flow_kwargs,
)
|
||||||
|
|
||||||
"""## 6. Train the Model"""
|
"""## 6. Train the Model"""
|
||||||
|
|
||||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
early_stopping = EarlyStopping(
|
|
||||||
monitor='val_loss',
|
|
||||||
patience=10,
|
|
||||||
restore_best_weights=True
|
|
||||||
)
|
|
||||||
|
|
||||||
model_checkpoint = ModelCheckpoint(
|
model_checkpoint = ModelCheckpoint(
|
||||||
'best_model.keras',
|
'best_model.keras',
|
||||||
monitor='val_accuracy',
|
monitor='val_accuracy',
|
||||||
@@ -339,12 +402,11 @@ model_checkpoint = ModelCheckpoint(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Train the model
|
# Train the model
|
||||||
print("Training model...")
|
|
||||||
history = model.fit(
|
history = model.fit(
|
||||||
train_generator,
|
train_generator,
|
||||||
epochs=50, # You can adjust the number of epochs
|
epochs=50, # You can adjust the number of epochs
|
||||||
validation_data=val_generator,
|
validation_data=val_generator,
|
||||||
callbacks=[early_stopping, model_checkpoint],
|
callbacks=[model_checkpoint],
|
||||||
class_weight=class_weights # Use class weights to handle imbalance
|
class_weight=class_weights # Use class weights to handle imbalance
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user