Ok now we want to get better accuracy
This commit is contained in:
@@ -29,7 +29,7 @@ Original file is located at
|
||||
import sys
|
||||
IN_COLAB = 'google.colab' in sys.modules
|
||||
|
||||
# if IN_COLAB:
|
||||
#if IN_COLAB:
|
||||
# !pip install pandas numpy matplotlib seaborn pillow scikit-learn tensorflow
|
||||
# !pip install --upgrade kagglehub[pandas-datasets,hf-datasets]
|
||||
|
||||
@@ -203,6 +203,33 @@ val_df['target'] = val_df['target'].astype(str)
|
||||
print("Train target dtype after conversion:", train_df['target'].dtype)
|
||||
print("Validation target dtype after conversion:", val_df['target'].dtype)
|
||||
|
||||
"""### Balancing classes by oversampling the minority class"""
|
||||
|
||||
# Identify majority and minority classes in the training set
|
||||
class_counts = train_df['target'].value_counts()
|
||||
majority_class = class_counts.idxmax()
|
||||
minority_class = class_counts.idxmin()
|
||||
|
||||
# Get the DataFrames for majority and minority classes
|
||||
df_majority = train_df[train_df['target'] == majority_class]
|
||||
df_minority = train_df[train_df['target'] == minority_class]
|
||||
|
||||
# Oversample the minority class
|
||||
df_minority_oversampled = df_minority.sample(
|
||||
class_counts[majority_class], replace=True, random_state=42
|
||||
)
|
||||
|
||||
# Combine majority class with oversampled minority class
|
||||
train_df_balanced = pd.concat([df_majority, df_minority_oversampled])
|
||||
|
||||
# Shuffle the balanced DataFrame
|
||||
train_df_balanced = train_df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
|
||||
|
||||
print("Original train_df class distribution:")
|
||||
print(train_df['target'].value_counts())
|
||||
print("\nBalanced train_df_balanced class distribution:")
|
||||
print(train_df_balanced['target'].value_counts())
|
||||
|
||||
plt.figure(figsize=(6,3))
|
||||
|
||||
sns.countplot(data=df, x='target')
|
||||
@@ -320,17 +347,53 @@ val_generator = val_datagen.flow_from_dataframe(
|
||||
seed=42
|
||||
)
|
||||
|
||||
"""Now, the `train_generator` will use the `train_df_balanced` DataFrame, which has an equal number of samples for both classes. This will help the model learn more effectively from the minority class during training."""
|
||||
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
# Image dimensions
|
||||
IMG_WIDTH = 224
|
||||
IMG_HEIGHT = 224
|
||||
|
||||
# Data generators
|
||||
train_datagen = ImageDataGenerator(
|
||||
preprocessing_function=preprocess_input,
|
||||
rotation_range=20,
|
||||
width_shift_range=0.2,
|
||||
height_shift_range=0.2,
|
||||
shear_range=0.2,
|
||||
zoom_range=0.2,
|
||||
horizontal_flip=True,
|
||||
fill_mode='nearest'
|
||||
)
|
||||
|
||||
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
|
||||
|
||||
# Flow from balanced dataframe for training
|
||||
train_generator = train_datagen.flow_from_dataframe(
|
||||
dataframe=train_df_balanced, # Use the balanced DataFrame
|
||||
x_col='path',
|
||||
y_col='target',
|
||||
target_size=(IMG_WIDTH, IMG_HEIGHT),
|
||||
batch_size=32,
|
||||
class_mode='binary',
|
||||
seed=42
|
||||
)
|
||||
|
||||
val_generator = val_datagen.flow_from_dataframe(
|
||||
dataframe=val_df,
|
||||
x_col='path',
|
||||
y_col='target',
|
||||
target_size=(IMG_WIDTH, IMG_HEIGHT),
|
||||
batch_size=32,
|
||||
class_mode='binary',
|
||||
seed=42
|
||||
)
|
||||
|
||||
"""## 6. Train the Model"""
|
||||
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
# Callbacks
|
||||
early_stopping = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
restore_best_weights=True
|
||||
)
|
||||
|
||||
model_checkpoint = ModelCheckpoint(
|
||||
'best_model.keras',
|
||||
monitor='val_accuracy',
|
||||
@@ -339,12 +402,11 @@ model_checkpoint = ModelCheckpoint(
|
||||
)
|
||||
|
||||
# Train the model
|
||||
print("Training model...")
|
||||
history = model.fit(
|
||||
train_generator,
|
||||
epochs=50, # You can adjust the number of epochs
|
||||
validation_data=val_generator,
|
||||
callbacks=[early_stopping, model_checkpoint],
|
||||
callbacks=[model_checkpoint],
|
||||
class_weight=class_weights # Use class weights to handle imbalance
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user