
Dermy Image Classification Model

This repository contains a Python script that trains and compiles the Dermy image classification model variants using TensorFlow Lite Model Maker. The compiled models are stored in the models/ subfolder.

Installation

To get started, clone this repository and install the required dependencies.

git clone https://vcs.r0r-5chach.xyz/r0r-5chach/dermy-models.git
cd dermy-models
pip install -r requirements.txt
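If you are unsure whether pip installed the dependencies into the Python environment you will run the script from, a quick import check can help. This is a minimal sketch and is not part of the repository. It assumes requirements.txt provides the tflite_model_maker and tensorflow packages, and the helper name missing_modules is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the module names from `names` that cannot be found.

    importlib.util.find_spec locates a top-level module without
    importing it, so this check is cheap even for large packages.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed to be installed via requirements.txt.
    required = ["tflite_model_maker", "tensorflow"]
    missing = missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```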

Usage

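To train and compile the binary classification models, run the binary-compile.py script. It loads the training and test images from data/binary-classification/, trains one model for each supported architecture (MobileNetV2, EfficientNet-Lite3 and EfficientNet-Lite4), and saves each compiled model to the models/ subfolder.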

python binary-compile.py
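DataLoader.from_folder takes each image's label from the name of the subfolder it sits in, so the train/ and test/ folders under data/binary-classification/ should each contain one subfolder per class. The sketch below is not part of the repository; it counts the images in each class folder so that an empty or misnamed folder shows up before a long training run. The helper name and the set of image extensions are assumptions.

```python
from pathlib import Path

# Assumed set of image extensions to count; adjust to match the dataset.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def count_images_per_class(split_dir):
    """Map each class subfolder of `split_dir` to its number of image files."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        raise FileNotFoundError(f"expected a directory at {split_dir}")
    counts = {}
    # Each subfolder is treated as one class, mirroring how labels are inferred.
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, count_images_per_class("data/binary-classification/train") should return a dictionary with exactly two keys for the binary classifier, each mapped to a non-zero count.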

Project Structure

dermy-models/
├── README.md
├── LICENSE
├── requirements.txt
├── binary-compile.py
└── models/
    └── binary-classifiers/
        ├── dermy-efficientnet_lite3.tflite
        ├── dermy-efficientnet_lite4.tflite
        └── dermy-mobilenet_v2.tflite

Data Sources

The data used to train the models in this repository was sourced from Kaggle.

The specific dataset used for the binary classifier can be found here.

Model Compilation

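The binary-compile.py script handles model training and compilation. For each architecture listed in MODELS, it trains a classifier, evaluates it on the test split, and exports a float16-quantized TFLite model to the models/ folder as dermy-binary-classification-<model>.tflite.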

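# Train, evaluate and export one binary classifier per architecture.
import os

from tflite_model_maker import image_classifier
from tflite_model_maker import model_spec
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

# Expected layout: DATA/train/<class>/ and DATA/test/<class>/ image folders.
DATA = "data/binary-classification/"
MODELS = ["mobilenet_v2", "efficientnet_lite3", "efficientnet_lite4"]

# Labels are taken from the class subfolder names.
train_data = DataLoader.from_folder(os.path.join(DATA, "train"))
test_data = DataLoader.from_folder(os.path.join(DATA, "test"))
# Keep 80% of the training images for training and hold out 20% for validation.
train_data, valid_data = train_data.split(0.8)

# Float16 quantization stores weights as 16-bit floats, roughly halving file size.
config = QuantizationConfig.for_float16()

for name in MODELS:
    model = image_classifier.create(train_data,
                                    validation_data=valid_data,
                                    model_spec=model_spec.get(name),
                                    epochs=50,
                                    learning_rate=0.0001,
                                    dropout_rate=0.2,
                                    batch_size=64,
                                    use_augmentation=True)
    model.summary()

    # Report performance on the held-out test split.
    loss, accuracy = model.evaluate(test_data)
    print(f"{name}: test loss={loss:.4f}, test accuracy={accuracy:.4f}")

    # Export a float16-quantized .tflite file into ./models.
    model.export(export_dir="./models",
                 export_format=ExportFormat.TFLITE,
                 tflite_filename=f"dermy-binary-classification-{name}.tflite",
                 quantization_config=config)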
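Once a model has been exported, it can be run with the TensorFlow Lite interpreter. The following is a minimal sketch, not code from this repository. It assumes NumPy and Pillow are installed, that the model takes float32 RGB input scaled to [0, 1] (float16 quantization is generally expected to keep float32 inputs and outputs), and that the label order matches the alphabetical order of the training class folders. The function names top_prediction and classify are hypothetical.

```python
def top_prediction(scores, labels):
    """Return (label, score) for the highest-scoring class."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best], float(scores[best])


def classify(model_path, image_path, labels):
    """Run one image through a .tflite model and return the top class.

    Assumes float32 RGB input scaled to [0, 1]; check this against the
    metadata of the exported model before relying on the result.
    """
    # Imported lazily so top_prediction() works without these packages.
    import numpy as np
    import tensorflow as tf
    from PIL import Image

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # Input shape is [1, height, width, 3]; the size differs per architecture.
    _, height, width, _ = input_details["shape"]
    image = Image.open(image_path).convert("RGB").resize((int(width), int(height)))
    batch = np.asarray(image, dtype=np.float32)[np.newaxis, ...] / 255.0

    interpreter.set_tensor(input_details["index"], batch)
    interpreter.invoke()
    # Output shape is [1, num_classes]; take the scores for the single image.
    scores = interpreter.get_tensor(output_details["index"])[0]
    return top_prediction(list(scores), labels)
```

For example, classify("models/dermy-binary-classification-mobilenet_v2.tflite", "example.jpg", ["class_a", "class_b"]) returns the predicted label and its score, where the image path and the two label names are placeholders for your own data. Keeping the pure top_prediction helper separate from the interpreter code makes the label selection easy to test without TensorFlow installed.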

License

This project is licensed under the GNU General Public License v3.0 (GPLv3). See the LICENSE file for details.