dermy-models/binary-compile.py

35 lines
1.3 KiB
Python
Raw Normal View History

2024-06-03 16:27:04 +00:00
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
# Root folder of the binary-classification image set. The "train" and "test"
# sub-folder names are appended directly, so the trailing slash is required.
DATA = "data/binary-classification/"

# Model Maker spec names; each one is trained and exported in turn.
MODELS = [
    "mobilenet_v2",
    "efficientnet_lite3",
    "efficientnet_lite4",
]

# Load the labelled image folders for training and for final testing.
train_data = DataLoader.from_folder(f"{DATA}train")
test_data = DataLoader.from_folder(f"{DATA}test")

# Carve a validation set out of the training data.
# NOTE(review): split(0.8) presumably keeps 80% for training and hands the
# remaining 20% to validation -- confirm against the DataLoader docs.
train_data, valid_data = train_data.split(0.8)
# Train, evaluate and export one classifier per architecture listed in MODELS.
# Every model sees the same train/validation/test data and the same
# hyper-parameters, so the resulting metrics are directly comparable.
for model_name in MODELS:
    model = image_classifier.create(
        train_data,
        validation_data=valid_data,
        model_spec=model_spec.get(model_name),
        epochs=50,
        learning_rate=0.0001,
        dropout_rate=0.2,
        batch_size=64,
        use_augmentation=True,
    )
    model.summary()

    # Score the trained model on the test set and report the result next to
    # the architecture name, so runs for different models can be told apart.
    loss, accuracy = model.evaluate(test_data)
    print(f"{model_name}: test loss={loss:.4f}, test accuracy={accuracy:.4f}")

    # Export as a float16-quantized TFLite file, one file per architecture.
    # A fresh quantization config is built for every export.
    config = QuantizationConfig.for_float16()
    filename = f"dermy-binary-classification-{model_name}.tflite"
    model.export(
        export_dir="./models",
        export_format=ExportFormat.TFLITE,
        tflite_filename=filename,
        quantization_config=config,
    )