updates
This commit is contained in:
parent
be85eaf018
commit
f34f725e32
|
|
@ -1,34 +1,43 @@
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import os
|
import os
|
||||||
|
from safetensors.numpy import save_file
|
||||||
DIR = "data/binary-classification"
|
DIR = "data/binary-classification"
|
||||||
|
|
||||||
#Import Data
|
#Import Data
|
||||||
PATH = os.path.join(os.getcwd(), DIR)
|
PATH = os.path.join(os.getcwd(), DIR)
|
||||||
|
|
||||||
training_data = os.path.join(PATH, "train")
|
training_path = os.path.join(PATH, "train")
|
||||||
validation_data = os.path.join(PATH, "valid")
|
test_path = os.path.join(PATH, "test")
|
||||||
test_data = os.path.join(PATH, "test")
|
|
||||||
|
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
IMG_SIZE = (224,224)
|
IMG_SIZE = (224,224)
|
||||||
|
|
||||||
#TODO: Import data from both directories, then resplit into test, train, and validation
|
#TODO: Import data from both directories, then resplit into test, train, and validation
|
||||||
|
|
||||||
print(f"Train: {len(training_data)}\nValid: {len(validation_data)}\nTest: {len(test_data)}")
|
training_data = tf.keras.utils.image_dataset_from_directory(training_path,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
image_size=IMG_SIZE,
|
||||||
|
validation_split=0.2,
|
||||||
|
subset="training",
|
||||||
|
seed=1234)
|
||||||
|
validation_data = tf.keras.utils.image_dataset_from_directory(training_path,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
image_size=IMG_SIZE,
|
||||||
|
validation_split=0.2,
|
||||||
|
subset="validation",
|
||||||
|
seed=1234)
|
||||||
|
test_data = tf.keras.utils.image_dataset_from_directory(test_path,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
image_size=IMG_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#View Data
|
#View Data
|
||||||
plt.figure(figsize=(10,10))
|
|
||||||
for images, labels in training_data.take(1):
|
|
||||||
for i in range(9):
|
|
||||||
ax = plt.subplot(3, 3, i+1)
|
|
||||||
plt.imshow(images[i].numpy().astype("uint8"))
|
|
||||||
plt.title(class_names[labels[i]])
|
|
||||||
plt.axis("off")
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
#Init Prefetching
|
#Init Prefetching
|
||||||
|
|
@ -67,10 +76,10 @@ base_model.summary()
|
||||||
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
|
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
|
||||||
feature_batch_avg = global_avg_layer(feature_batch)
|
feature_batch_avg = global_avg_layer(feature_batch)
|
||||||
|
|
||||||
prediction_layer = tf.keras.layers.Dense(38, activation="softmax")
|
prediction_layer = tf.keras.layers.Dense(2, activation="softmax")
|
||||||
predication_batch = prediction_layer(feature_batch_avg)
|
predication_batch = prediction_layer(feature_batch_avg)
|
||||||
|
|
||||||
inputs = tf.keras.Input(shape=(160,160,3))
|
inputs = tf.keras.Input(shape=IMG_SHAPE)
|
||||||
x = data_augmentation(inputs)
|
x = data_augmentation(inputs)
|
||||||
x = base_model(x, training=False)
|
x = base_model(x, training=False)
|
||||||
x = global_avg_layer(x)
|
x = global_avg_layer(x)
|
||||||
|
|
@ -86,16 +95,14 @@ model.summary()
|
||||||
#Compile the Model
|
#Compile the Model
|
||||||
base_learning_rate = 0.0001
|
base_learning_rate = 0.0001
|
||||||
|
|
||||||
training_data = training_data.map(lambda x,y: (x, tf.one_hot(y,38)))
|
training_data = training_data.map(lambda x,y: (x, tf.one_hot(y,2)))
|
||||||
validation_data = validation_data.map(lambda x,y: (x, tf.one_hot(y,38)))
|
validation_data = validation_data.map(lambda x,y: (x, tf.one_hot(y,2)))
|
||||||
test_data = test_data.map(lambda x,y: (x, tf.one_hot(y,38)))
|
test_data = test_data.map(lambda x,y: (x, tf.one_hot(y,2)))
|
||||||
|
|
||||||
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)
|
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)
|
||||||
loss = tf.keras.losses.CategoricalCrossentropy()
|
loss = tf.keras.losses.CategoricalCrossentropy()
|
||||||
metrics = [tf.keras.metrics.CategoricalAccuracy()]
|
metrics = [tf.keras.metrics.CategoricalAccuracy()]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||||||
|
|
||||||
#Train the Model
|
#Train the Model
|
||||||
|
|
@ -124,7 +131,6 @@ history = model.fit(training_data,
|
||||||
validation_data=validation_data,
|
validation_data=validation_data,
|
||||||
callbacks=[lr_schedule, early_stopping])
|
callbacks=[lr_schedule, early_stopping])
|
||||||
|
|
||||||
model.save("crop-classifier-better-test.keras")
|
|
||||||
|
|
||||||
#Evaluate Model
|
#Evaluate Model
|
||||||
results = model.evaluate(validation_data)
|
results = model.evaluate(validation_data)
|
||||||
|
|
@ -134,3 +140,8 @@ print(f"Validation Accuracy: {results[1]}")
|
||||||
results = model.evaluate(test_data)
|
results = model.evaluate(test_data)
|
||||||
print(f"Test Loss: {results[0]}")
|
print(f"Test Loss: {results[0]}")
|
||||||
print(f"Test Accuracy: {results[1]}")
|
print(f"Test Accuracy: {results[1]}")
|
||||||
|
|
||||||
|
weights = model.get_weights()
|
||||||
|
weights_dict = {f"weight_{i}": w for i, w in enumerate(weights)}
|
||||||
|
|
||||||
|
save_file(weights_dict, "models/mobilenet_v3.safetensors")
|
||||||
|
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
[build-system]
|
|
||||||
requires = ["setuptools", "wheel"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[project]
|
|
||||||
name = "dermy-model"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "A Image Classification Model for classifying Moles"
|
|
||||||
authors = [{ name = "r0r5chach", email = "r0r-5chach.xyz@proton.me" }]
|
|
||||||
readme = "README.md"
|
|
||||||
license = { file = "LICENSE" }
|
|
||||||
dependencies = [
|
|
||||||
"matplotlib",
|
|
||||||
"numpy",
|
|
||||||
"tensorflow",
|
|
||||||
]
|
|
||||||
|
|
||||||
[too.setuptools]
|
|
||||||
packages = ["dermy-model"]
|
|
||||||
include_package_data = true
|
|
||||||
Loading…
Reference in New Issue