csy3025-assignment-2/main.py

149 lines
4.2 KiB
Python
Raw Permalink Normal View History

2024-05-18 15:22:07 +00:00
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
2024-05-19 22:06:27 +00:00
2024-05-18 15:22:07 +00:00
# Import Data
# Images are read from crop-data/<split>/<class_name>/...; class labels are
# inferred from the sub-directory names (exposed below as class_names).
PATH = os.path.join(os.getcwd(), "crop-data")
training_dir = os.path.join(PATH, "train")
validation_dir = os.path.join(PATH, "valid")
# NOTE(review): crop-data/test is not loaded; the test split is held out from
# crop-data/train instead. Confirm that is intended.
BATCH_SIZE = 64
IMG_SIZE = (160, 160)
# Seed for the train/test file split. Keras requires a seed when combining
# validation_split with shuffling, and it keeps the held-out set reproducible.
SPLIT_SEED = 42
# File extensions accepted by image_dataset_from_directory (case-insensitive).
_IMAGE_EXTENSIONS = (".bmp", ".gif", ".jpeg", ".jpg", ".png")


def _count_images(directory):
    """Return how many files under *directory* have an image file extension."""
    return sum(
        name.lower().endswith(_IMAGE_EXTENSIONS)
        for _, _, names in os.walk(directory)
        for name in names
    )


# Hold out a test set from the training directory that is about as large as
# the validation set. The split is made on file paths (validation_split with
# subset="both" and a fixed seed), so train and test never share an image.
# Splitting the already-batched dataset with skip()/take() is not safe: with
# shuffle=True the loader also shuffles locally on every pass over the data, so
# images close to the skip/take boundary move between "train" and "test" from
# one epoch to the next and end up being trained on and then tested on.
n_train_images = _count_images(training_dir)
n_test_images = _count_images(validation_dir)
if not 0 < n_test_images < n_train_images:
    raise ValueError(
        f"Cannot hold out {n_test_images} test images from the "
        f"{n_train_images} images found in {training_dir!r}"
    )
training_data, test_data = tf.keras.utils.image_dataset_from_directory(
    training_dir,
    validation_split=n_test_images / n_train_images,
    subset="both",
    seed=SPLIT_SEED,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)
validation_data = tf.keras.utils.image_dataset_from_directory(
    validation_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)
class_names = training_data.class_names
# Labels are class indices, so both directories must list the same classes in
# the same order for validation scores to be meaningful.
if validation_data.class_names != class_names:
    raise ValueError(
        f"Class mismatch between {training_dir!r} and {validation_dir!r}"
    )
# Number of training batches after the test split has been held out.
train_size = len(training_data)
print(f"Train: {len(training_data)}\nValid: {len(validation_data)}\nTest: {len(test_data)}")
# View Data
# Preview a 3x3 grid of images from the first training batch, each titled with
# its class name.
plt.figure(figsize=(10, 10))
for batch_images, batch_labels in training_data.take(1):
    for idx in range(9):
        plt.subplot(3, 3, idx + 1)
        # Convert to uint8 so imshow renders the pixel values as an image.
        plt.imshow(batch_images[idx].numpy().astype("uint8"))
        plt.title(class_names[batch_labels[idx]])
        plt.axis("off")
plt.show()
# Init Prefetching
# Let tf.data prepare upcoming batches while the current one is being consumed;
# AUTOTUNE lets the runtime choose the buffer size.
AUTOTUNE = tf.data.AUTOTUNE
training_data, validation_data, test_data = (
    split.prefetch(buffer_size=AUTOTUNE)
    for split in (training_data, validation_data, test_data)
)
# Data Augmentation
# Random horizontal flips plus random rotations. Per the Keras docs, these
# random preprocessing layers are only active during training, and
# RandomRotation's factor is a fraction of a full turn (0.2 -> up to +/-72 deg).
data_augmentation = tf.keras.Sequential()
data_augmentation.add(tf.keras.layers.RandomFlip("horizontal"))
data_augmentation.add(tf.keras.layers.RandomRotation(0.2))
# Create Base Model From MobileNetV3
# ImageNet-pretrained MobileNetV3Large without its classifier top, used as a
# frozen feature extractor for the new classification head.
IMG_SHAPE = (*IMG_SIZE, 3)
base_model = tf.keras.applications.MobileNetV3Large(
    weights="imagenet",
    include_top=False,
    input_shape=IMG_SHAPE,
)
# Push one training batch through the backbone; feature_batch is consumed by
# the classification-head section of this script.
image_batch, label_batch = next(iter(training_data))
feature_batch = base_model(image_batch)
# Freeze the pretrained weights so only the head is updated during fit().
base_model.trainable = False
# View Base Model
base_model.summary()
# Add Classification Header
# Output width comes from the loaded class list instead of a hard-coded 38, so
# the head always matches the class sub-directories that were found.
num_classes = len(class_names)
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_avg = global_avg_layer(feature_batch)
prediction_layer = tf.keras.layers.Dense(num_classes, activation="softmax")
# Calling the head once on the sample backbone features builds the Dense layer.
predication_batch = prediction_layer(feature_batch_avg)
# End-to-end model: augmentation -> frozen backbone -> pooling -> dropout ->
# softmax classifier. The input shape reuses IMG_SHAPE rather than repeating
# the literal (160, 160, 3).
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
# training=False runs the backbone in inference mode even while the head trains.
x = base_model(x, training=False)
x = global_avg_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
# View Model with Classification Head
model.summary()
# Compile the Model
base_learning_rate = 0.0001
# CategoricalCrossentropy / CategoricalAccuracy expect one-hot targets, so the
# integer labels are one-hot encoded. The depth is the number of loaded
# classes rather than a hard-coded 38, so it follows the class list.
num_classes = len(class_names)


def _one_hot_labels(images, labels):
    """Return (images, labels) with the integer class *labels* one-hot encoded."""
    return images, tf.one_hot(labels, num_classes)


# The datasets were prefetched before this map; prefetch again afterwards so
# prefetching remains the last stage of every input pipeline.
training_data = training_data.map(_one_hot_labels).prefetch(tf.data.AUTOTUNE)
validation_data = validation_data.map(_one_hot_labels).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(_one_hot_labels).prefetch(tf.data.AUTOTUNE)
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)
loss = tf.keras.losses.CategoricalCrossentropy()
metrics = [tf.keras.metrics.CategoricalAccuracy()]
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
# Train the Model
initial_epochs = 50
# Baseline validation scores before any training of the head.
loss0, accuracy0 = model.evaluate(validation_data)
print(f"initial loss: {loss0}")
print(f"initial accuracy: {accuracy0}")
# Stop after 10 epochs without val_loss improvement and roll back to the best
# weights seen.
# NOTE(review): in some Keras versions restore_best_weights only applies when
# training actually stops early, not when all epochs run - confirm for the
# installed version.
early_stopping = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=10,
    monitor="val_loss",
)
# Multiply the learning rate by 0.1 after 5 epochs without val_loss
# improvement, never going below 1e-6.
lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
    min_lr=1e-6,
    patience=5,
    factor=0.1,
    monitor="val_loss",
)
history = model.fit(
    training_data,
    validation_data=validation_data,
    epochs=initial_epochs,
    callbacks=[lr_schedule, early_stopping],
)
model.save("crop-classifier-better-test.keras")
# Evaluate Model
# model.evaluate returns [loss, accuracy], in the order configured by compile().

# Validation split (also used for early stopping / LR scheduling during fit).
results = model.evaluate(validation_data)
print(f"Validation Loss: {results[0]}")
print(f"Validation Accuracy: {results[1]}")

# Held-out test split, carved from the training directory and never passed to
# fit().
results = model.evaluate(test_data)
print(f"Test Loss: {results[0]}")
print(f"Test Accuracy: {results[1]}")