# File-view metadata from the original listing: 35 lines, 1.3 KiB, Python.
import os

from tflite_model_maker import image_classifier
from tflite_model_maker import model_spec
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

# Root of the dataset. It must contain "train" and "test" sub-folders in the
# layout DataLoader.from_folder reads (see the Model Maker docs for details).
DATA = "data/binary-classification/"

# Backbones to fine-tune; each name is passed to model_spec.get() and one
# .tflite file is exported per entry.
MODELS = ["mobilenet_v2", "efficientnet_lite3", "efficientnet_lite4"]

# Fraction passed to DataLoader.split(): the first returned part is kept for
# training and the remainder is used as the validation set.
TRAIN_FRACTION = 0.8

# os.path.join works whether or not DATA ends with a path separator; plain
# string concatenation would silently produce a wrong path without the "/".
train_data = DataLoader.from_folder(os.path.join(DATA, "train"))
test_data = DataLoader.from_folder(os.path.join(DATA, "test"))
train_data, valid_data = train_data.split(TRAIN_FRACTION)

# Every exported model uses float16 post-training quantization. The config does
# not depend on the model, so build it once instead of once per iteration.
config = QuantizationConfig.for_float16()

for model_name in MODELS:
    # Train a classifier with the backbone selected by model_name, using the
    # validation split created above and fixed hyper-parameters.
    model = image_classifier.create(
        train_data,
        validation_data=valid_data,
        model_spec=model_spec.get(model_name),
        epochs=50,
        learning_rate=0.0001,
        dropout_rate=0.2,
        batch_size=64,
        use_augmentation=True,
    )
    model.summary()

    # Evaluate on the held-out test folder and report the result, so the
    # per-backbone comparison is visible instead of being discarded.
    loss, accuracy = model.evaluate(test_data)
    print(f"{model_name}: test loss={loss:.4f}, test accuracy={accuracy:.4f}")

    # One TFLite file per backbone, written to ./models.
    filename = f"dermy-binary-classification-{model_name}.tflite"
    model.export(
        export_dir="./models",
        export_format=ExportFormat.TFLITE,
        tflite_filename=filename,
        quantization_config=config,
    )