On-Device Training: Aufbau einer Android-Anwendung

In diesem Tutorial erfahren Sie, wie Sie eine Android-Anwendung erstellen, die die On-Device-Training-Lösung von ONNX Runtime integriert. On-Device Training bezieht sich auf den Prozess des Trainings eines Machine-Learning-Modells direkt auf einem Edge-Gerät, ohne auf Cloud-Dienste oder externe Server zurückzugreifen.

So wird die Anwendung am Ende dieses Tutorials aussehen

an image classification app with Tom Cruise in the middle.

Einleitung

Wir führen Sie durch die Schritte zur Erstellung einer Android-App, die ein einfaches Bildklassifizierungsmodell mithilfe von On-Device-Trainingstechniken trainieren kann. Dieses Tutorial zeigt die Transfer-Learning-Technik, bei der Wissen, das aus dem Training eines Modells für eine Aufgabe gewonnen wurde, genutzt wird, um die Leistung eines Modells für eine andere, aber verwandte Aufgabe zu verbessern. Anstatt den Lernprozess von Grund auf neu zu beginnen, ermöglicht Transfer Learning, dass Wissen oder Merkmale, die von einem vortrainierten Modell gelernt wurden, auf eine neue Aufgabe übertragen werden.

Für dieses Tutorial verwenden wir das Modell MobileNetV2, das auf groß angelegten Bilddatensätzen wie ImageNet (das 1.000 Klassen hat) trainiert wurde. Wir werden dieses Modell verwenden, um benutzerdefinierte Daten in eine von vier Klassen zu klassifizieren. Die anfänglichen Schichten von MobileNetV2 dienen als Merkmalsextraktor und erfassen allgemeine visuelle Merkmale, die für verschiedene Aufgaben anwendbar sind. Nur die letzte Klassifizierungsschicht wird für die jeweilige Aufgabe trainiert.

In diesem Tutorial lernen wir mit Daten zu

  • Klassifizieren Sie Tiere in eine von vier Kategorien mithilfe eines vorinstallierten Tierdatensatzes.
  • Klassifizieren Sie Prominente in eine von vier Kategorien mithilfe eines benutzerdefinierten Prominentendatensatzes.

Inhaltsverzeichnis

Voraussetzungen

Um dieses Tutorial befolgen zu können, sollten Sie über grundlegende Kenntnisse in der Android-App-Entwicklung mit Java oder Kotlin verfügen. Kenntnisse in C++ sowie Kenntnisse von Machine-Learning-Konzepten wie neuronalen Netzen und Bildklassifizierung sind ebenfalls hilfreich.

  • Python-Entwicklungsumgebung zur Vorbereitung der Trainingsartefakte
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • Ein Android-Gerät mit Kamera im Entwicklermodus mit aktiviertem USB-Debugging

Hinweis Die gesamte Android-Anwendung ist auch im GitHub-Repository onnxruntime-training-examples verfügbar.

Offline-Phase – Erstellen der Trainingsartefakte

  1. Exportieren Sie das Modell nach ONNX.

    Wir beginnen mit einem vortrainierten PyTorch-Modell und exportieren es nach ONNX. Das Modell MobileNetV2 wurde auf dem Imagenet-Datensatz vortrainiert, der Daten in 1.000 Kategorien enthält. Für unsere Aufgabe der Bildklassifizierung möchten wir nur Bilder in 4 Klassen klassifizieren. Daher ändern wir die letzte Schicht des Modells so, dass sie 4 Logits anstelle von 1.000 ausgibt.

    Weitere Details zum Exportieren von PyTorch-Modellen nach ONNX finden Sie hier.

    import torch
    import torchvision
    
    model = torchvision.models.mobilenet_v2(
       weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    
    # The original model is trained on imagenet which has 1000 classes.
    # For our image classification scenario, we need to classify among 4 categories.
    # So we need to change the last layer of the model to have 4 outputs.
    model.classifier[1] = torch.nn.Linear(1280, 4)
    
    # Export the model to ONNX.
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                      f"training_artifacts/{model_name}.onnx",
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. Definieren Sie die trainierbaren und nicht trainierbaren Parameter

    import onnx
    
    # Load the onnx model.
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    
    # Define the parameters that require their gradients to be computed
    # (trainable parameters) and those that do not (frozen/non trainable parameters).
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
       param.name
       for param in onnx_model.graph.initializer
       if param.name not in requires_grad
    ]
    
  3. Generieren Sie die Trainingsartefakte.

    Wir verwenden für dieses Tutorial den CrossEntropyLoss-Verlust und den AdamW-Optimierer. Weitere Details zur Artefakterstellung finden Sie hier.

    from onnxruntime.training import artifacts
    
    # Generate the training artifacts.
    artifacts.generate_artifacts(
       onnx_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory="training_artifacts"
    )
    

    Das ist alles! Die Trainingsartefakte wurden im Ordner training_artifacts generiert. Dies markiert das Ende der Offline-Phase. Diese Artefakte sind nun bereit für die Bereitstellung auf dem Android-Gerät für das Training.

Trainingsphase – Android-Anwendungsentwicklung

  1. Einrichten des Projekts in Android Studio

    a. Öffnen Sie Android Studio und klicken Sie auf Neues Projekt Android Studio Setup - Neues Projekt

    b. Klicken Sie auf Native C++ -> Weiter. Füllen Sie die Details des Neues Projekt wie folgt aus:

    • Name - ORT Personalize
    • Paketname - com.example.ortpersonalize
    • Sprache - Kotlin

    Klicken Sie auf Weiter.

    Android Studio Setup - Project Name

    c. Wählen Sie die C++17-Toolchain -> Fertigstellen

    Android Studio Setup - Project C++ ToolChain

    d. Das war's! Das Android Studio-Projekt ist eingerichtet. Sie sollten nun den Android Studio-Editor mit etwas Boilerplate-Code sehen.

  2. Hinzufügen der ONNX Runtime-Abhängigkeit

    a. Erstellen Sie zwei neue Ordner namens lib und include\onnxruntime unter dem cpp-Verzeichnis im Android Studio-Projekt.

    lib and include folder

    b. Gehen Sie zu Maven Central. Gehen Sie zu Versionen->Durchsuchen-> und laden Sie das Archivpaket (aar-Datei) onnxruntime-training-android herunter.

    c. Benennen Sie die Erweiterung aar in zip um. Also wird onnxruntime-training-android-1.15.0.aar zu onnxruntime-training-android-1.15.0.zip.

    d. Entpacken Sie den Inhalt der Zip-Datei.

    e. Kopieren Sie die gemeinsam genutzte Bibliothek libonnxruntime.so aus dem Ordner jni\arm64-v8a in Ihr Android-Projekt unter den neu erstellten Ordner lib.

    f. Kopieren Sie den Inhalt des Ordners headers in den neu erstellten Ordner include\onnxruntime.

    g. Fügen Sie in der Datei native-lib.cpp die Trainings-C++-Header-Datei ein.

    #include "onnxruntime_training_cxx_api.h"
    

    h. Fügen Sie abiFilters zur Datei build.gradle (Modul) hinzu, um arm64-v8a auszuwählen. Diese Einstellung muss unter defaultConfig in build.gradle hinzugefügt werden.

    ndk {
       abiFilters 'arm64-v8a'
    }
    

    Beachten Sie, dass der Abschnitt defaultConfig der build.gradle-Datei wie folgt aussehen sollte:

     defaultConfig {
        applicationId "com.example.ortpersonalize"
        minSdk 29
        targetSdk 33
        versionCode 1
        versionName "1.0"
    
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
           cmake {
                 cppFlags '-std=c++17'
           }
        }
    +   ndk {
    +       abiFilters 'arm64-v8a'
    +   }
       
     }
    

    i. Fügen Sie die gemeinsam genutzte Bibliothek onnxruntime zur Datei CMakeLists.txt hinzu, damit cmake die gemeinsam genutzte Bibliothek finden und dagegen linken kann. Fügen Sie dazu diese Zeilen hinzu, nachdem die Bibliothek ortpersonalize in CMakeLists.txt hinzugefügt wurde:

    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    

    Lassen Sie CMake wissen, wo die ONNX Runtime-Header-Dateien zu finden sind, indem Sie diese Zeile direkt nach den obigen beiden Zeilen hinzufügen:

    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    

    Linken Sie das Android C++-Projekt gegen die Bibliothek onnxruntime, indem Sie die Bibliothek onnxruntime zu target_link_libraries hinzufügen.

    target_link_libraries( # Specifies the target library.
         ortpersonalize
    
         # Links the target library to the log library
         # included in the NDK.
         ${log-lib}
    
         onnxruntime)
    

    Beachten Sie, dass die Datei CMakeLists.txt wie folgt aussehen sollte:

    project("ortpersonalize")
    
    add_library( # Sets the name of the library.
          ortpersonalize
    
          # Sets the library as a shared library.
          SHARED
    
          # Provides a relative path to your source file(s).
          native-lib.cpp
    +     utils.cpp
    +     inference.cpp
    +     train.cpp)
    + add_library(onnxruntime SHARED IMPORTED)
    + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    
    find_library( # Sets the name of the path variable.
          log-lib
    
          # Specifies the name of the NDK library that
          # you want CMake to locate.
          log)
    
    target_link_libraries( # Specifies the target library.
          ortpersonalize
    
          # Links the target library to the log library
          # included in the NDK.
          ${log-lib}
    +     onnxruntime)
    
    

    j. Erstellen Sie die Anwendung und warten Sie auf Erfolg, um zu bestätigen, dass die App die ONNX Runtime-Header integriert hat und erfolgreich gegen die gemeinsam genutzte ONNX Runtime-Bibliothek linken kann.

  3. Verpacken der vorab erstellten Trainingsartefakte und des Datensatzes

    a. Erstellen Sie einen neuen Ordner assets im Ordner app im linken Bereich des Android Studio-Projekts, indem Sie mit der rechten Maustaste auf app klicken -> Neu -> Ordner -> Assets-Ordner und ihn unter main platzieren.

    b. Kopieren Sie die in Schritt 2 generierten Trainingsartefakte in diesen Ordner.

    c. Gehen Sie nun zum onnxruntime-training-examples Repo und laden Sie den Datensatz (images.zip) auf Ihren Computer herunter und entpacken Sie ihn. Dieser Datensatz wurde aus dem ursprünglichen animals-10-Datensatz auf Kaggle modifiziert, der von Corrado Alessio erstellt wurde.

    d. Kopieren Sie den heruntergeladenen Ordner images in das Verzeichnis assets/images in Android Studio.

    Der linke Bereich des Projekts sollte wie folgt aussehen:

    Project Assets

  4. Schnittstelle mit ONNX Runtime – C++-Code

    a. Wir werden die folgenden vier Funktionen in C++ implementieren, die von der Anwendung aufgerufen werden:

    • createSession: Wird beim Start der Anwendung aufgerufen. Es werden neue Objekte vom Typ CheckpointState und TrainingSession erstellt.
    • releaseSession: Wird aufgerufen, wenn die Anwendung geschlossen wird. Diese Funktion gibt die zu Beginn der Anwendung zugewiesenen Ressourcen frei.
    • performTraining: Wird aufgerufen, wenn der Benutzer auf die Schaltfläche Train auf der Benutzeroberfläche klickt.
    • performInference: Wird aufgerufen, wenn der Benutzer auf die Schaltfläche Infer auf der Benutzeroberfläche klickt.

    b. Sitzung erstellen

    Diese Funktion wird beim Start der Anwendung aufgerufen. Sie verwendet die Trainingsartefakte im Assets-Ordner, um die C++-Objekte CheckpointState und TrainingSession zu erstellen. Diese Objekte werden für das Training des Modells auf dem Gerät verwendet.

    Die Argumente für createSession sind:

    • checkpoint_path: Gespeichertes Pfad zum Checkpoint-Artefakt.
    • train_model_path: Gespeichertes Pfad zum Trainingsmodell-Artefakt.
    • eval_model_path: Gespeichertes Pfad zum Evaluierungsmodell-Artefakt.
    • optimizer_model_path: Gespeichertes Pfad zum Optimierermodell-Artefakt.
    • cache_dir_path: Pfad zum Cache-Verzeichnis auf dem Android-Gerät. Das Cache-Verzeichnis wird als Zugriffsmethode auf die Trainingsartefakte aus dem C++-Code verwendet.

    Die Funktion gibt eine long zurück, die einen Zeiger auf das Objekt session_cache darstellt. Dieser long kann in SessionCache umgewandelt werden, wann immer wir Zugriff auf die Trainingssitzung benötigen.

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
          JNIEnv *env, jobject /* this */,
          jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
          jstring optimizer_model_path, jstring cache_dir_path)
    {
       std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
                utils::JString2String(env, checkpoint_path),
                utils::JString2String(env, train_model_path),
                utils::JString2String(env, eval_model_path),
                utils::JString2String(env, optimizer_model_path),
                utils::JString2String(env, cache_dir_path));
       return reinterpret_cast<long>(session_cache.release());
    }
    

    Wie aus dem obigen Funktionskörper ersichtlich ist, erstellt diese Funktion einen eindeutigen Zeiger auf ein Objekt der Klasse SessionCache. Die Definition von SessionCache ist unten angegeben.

    struct SessionCache {
       ArtifactPaths artifact_paths;
       Ort::Env ort_env;
       Ort::SessionOptions session_options;
       Ort::CheckpointState checkpoint_state;
       Ort::TrainingSession training_session;
       Ort::Session* inference_session;
    
       SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
                   const std::string &eval_model_path, const std::string &optimizer_model_path,
                   const std::string& cache_dir_path) :
       artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
       ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
       checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
       training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
                         artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
       inference_session(nullptr) {}
    };
    

    Die Definition von ArtifactPaths lautet:

    struct ArtifactPaths {
       std::string checkpoint_path;
       std::string training_model_path;
       std::string eval_model_path;
       std::string optimizer_model_path;
       std::string cache_dir_path;
       std::string inference_model_path;
    
       ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
                      const std::string &eval_model_path, const std::string &optimizer_model_path,
                      const std::string& cache_dir_path) :
       checkpoint_path(checkpoint_path), training_model_path(training_model_path),
       eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
       cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };
    

    c. Sitzung freigeben

    Diese Funktion wird aufgerufen, wenn die Anwendung heruntergefahren wird. Sie gibt die Ressourcen frei, die beim Start der Anwendung erstellt wurden, hauptsächlich CheckpointState und TrainingSession.

    Die Argumente für releaseSession sind:

    • session: long, das das Objekt SessionCache darstellt.
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
          JNIEnv *env, jobject /* this */,
          jlong session) {
       auto *session_cache = reinterpret_cast<SessionCache *>(session);
       delete session_cache->inference_session;
       delete session_cache;
    }
    

    d. Training durchführen

    Diese Funktion wird für jeden Batch aufgerufen, der trainiert werden muss. Die Trainingsschleife ist auf der Anwendungsseite in Kotlin geschrieben, und innerhalb der Trainingsschleife wird die Funktion performTraining für jeden Batch aufgerufen.

    Die Argumente für performTraining sind:

    • session: long, das das Objekt SessionCache darstellt.
    • batch: Eingabebilder als Float-Array, das für das Training übergeben wird.
    • labels: Labels als Integer-Array, die den für das Training bereitgestellten Eingabebildern zugeordnet sind.
    • batch_size: Anzahl der Bilder, die mit jedem TrainStep verarbeitet werden.
    • channels: Anzahl der Farbkanäle im Bild. Für unser Beispiel wird dieser Wert immer mit 3 aufgerufen.
    • frame_rows: Anzahl der Zeilen im Bild. Für unser Beispiel wird dieser Wert immer mit 224 aufgerufen.
    • frame_cols: Anzahl der Spalten im Bild. Für unser Beispiel wird dieser Wert immer mit 224 aufgerufen.

    Die Funktion gibt einen float zurück, der den Trainingsverlust für diesen Batch darstellt.

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
          JNIEnv *env, jobject /* this */,
          jlong session, jfloatArray batch, jintArray labels, jint batch_size,
          jint channels, jint frame_rows, jint frame_cols) {
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
    
       if (session_cache->inference_session) {
          // Invalidate the inference session since we will be updating the model parameters
          // in train_step.
          // The next call to inference session will need to recreate the inference session.
          delete session_cache->inference_session;
          session_cache->inference_session = nullptr;
       }
    
       // Update the model parameters using this batch of inputs.
       return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
                                  env->GetIntArrayElements(labels, nullptr), batch_size,
                                  channels, frame_rows, frame_cols);
    }
    

    Die obige Funktion nutzt die Funktion train_step. Die Definition der Funktion train_step lautet wie folgt:

    namespace training {
    
       float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
                         int64_t batch_size, int64_t image_channels, int64_t image_rows,
                         int64_t image_cols) {
          const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
          const std::vector<int64_t> labels_shape({batch_size});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> user_inputs; // {inputs, labels}
          // Inputs batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
          // Labels batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
                                                             batch_size * sizeof(int32_t),
                                                             labels_shape.data(), labels_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    
          // Run the train step and execute the forward + loss + backward.
          float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    
          // Update the model parameters by taking a step in the direction of the gradients computed above.
          session_cache->training_session.OptimizerStep();
    
          // Reset the gradients now that the parameters have been updated.
          // New set of gradients can then be computed for the next round of inputs.
          session_cache->training_session.LazyResetGrad();
    
          return loss;
       }
    
    } // namespace training
    

    e. Inferenz durchführen

    Diese Funktion wird aufgerufen, wenn der Benutzer eine Inferenz durchführen möchte.

    Die Argumente für performInference sind:

    • session: long, das das Objekt SessionCache darstellt.
    • image_buffer: Eingabebilder als Float-Array, das für das Training übergeben wird.
    • batch_size: Anzahl der Bilder, die bei jeder Inferenz verarbeitet werden. Für unser Beispiel wird dieser Wert immer mit 1 aufgerufen.
    • image_channels: Anzahl der Farbkanäle im Bild. Für unser Beispiel wird dieser Wert immer mit 3 aufgerufen.
    • image_rows: Anzahl der Zeilen im Bild. Für unser Beispiel wird dieser Wert immer mit 224 aufgerufen.
    • image_cols: Anzahl der Spalten im Bild. Für unser Beispiel wird dieser Wert immer mit 224 aufgerufen.
    • classes: Liste von Zeichenketten, die alle vier benutzerdefinierten Klassen darstellen.

    Die Funktion gibt eine string zurück, die eine der vier bereitgestellten benutzerdefinierten Klassen darstellt. Dies ist die Vorhersage des Modells.

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
          JNIEnv *env, jobject  /* this */,
          jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
          jint image_cols, jobjectArray classes) {
    
       std::vector<std::string> classes_str;
       for (int i = 0; i < env->GetArrayLength(classes); ++i) {
          // Access the current string element
          jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
          classes_str.push_back(utils::JString2String(env, elem));
       }
    
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
       if (!session_cache->inference_session) {
          // The inference session does not exist, so create a new one.
          session_cache->training_session.ExportModelForInferencing(
                   session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
          session_cache->inference_session = std::make_unique<Ort::Session>(
                   session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
                   session_cache->session_options).release();
       }
    
       auto prediction = inference::classify(
                session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
                batch_size, image_channels, image_rows, image_cols, classes_str);
    
       return env->NewStringUTF(prediction.first.c_str());
    }
    

    Die obige Funktion ruft classify auf. Die Definition von classify lautet:

    namespace inference {
    
       std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
                                              int64_t batch_size, int64_t image_channels,
                                              int64_t image_rows, int64_t image_cols,
                                              const std::vector<std::string>& classes) {
          std::vector<const char *> input_names = {"input"};
          size_t input_count = 1;
    
          std::vector<const char *> output_names = {"output"};
          size_t output_count = 1;
    
          std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> input_values; // {input images}
          input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
    
          std::vector<Ort::Value> output_values;
          output_values.emplace_back(nullptr);
    
          // get the logits
          session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
                                                 input_count, output_names.data(), output_values.data(), output_count);
    
          float *output = output_values.front().GetTensorMutableData<float>();
    
          // run softmax and get the probabilities of each class
          std::vector<float> probabilities = Softmax(output, classes.size());
          size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    
          return {classes[best_index], probabilities[best_index]};
       }
    
    } // namespace inference
    

    Die Funktion classify ruft eine weitere Funktion namens Softmax auf. Die Definition von Softmax lautet:

    std::vector<float> Softmax(float *logits, size_t num_logits) {
       std::vector<float> probabilities(num_logits, 0);
       float sum = 0;
       for (size_t i = 0; i < num_logits; ++i) {
             probabilities[i] = exp(logits[i]);
             sum += probabilities[i];
       }
    
       if (sum != 0.0f) {
             for (size_t i = 0; i < num_logits; ++i) {
                probabilities[i] /= sum;
             }
       }
    
       return probabilities;
    }
    
  5. Bildvorverarbeitung

    a. Das Modell MobileNetV2 erwartet, dass das bereitgestellte Eingabebild:

    • Größe 3 x 224 x 224 hat.
    • ein normalisiertes Bild, bei dem der Mittelwert (0.485, 0.456, 0.406) subtrahiert und durch die Standardabweichung (0.229, 0.224, 0.225) geteilt wurde.

    Diese Vorverarbeitung erfolgt in Java/Kotlin unter Verwendung der von Android bereitgestellten Bibliotheken.

    Erstellen wir eine neue Datei namens ImageProcessingUtil.kt im Verzeichnis app/src/main/java/com/example/ortpersonalize. In dieser Datei fügen wir die Hilfsmethoden zum Zuschneiden und Skalieren sowie zum Normalisieren der Bilder hinzu.

    b. Zuschneiden und Skalieren des Bildes.

    fun processBitmap(bitmap: Bitmap) : Bitmap {
       // This function processes the given bitmap by
       //   - cropping along the longer dimension to get a square bitmap
       //     If the width is larger than the height
       //     ___+_________________+___
       //     |  +                 +  |
       //     |  +                 +  |
       //     |  +        +        +  |
       //     |  +                 +  |
       //     |__+_________________+__|
       //     <-------- width -------->
       //        <----- height ---->
       //     <-->      cropped    <-->
       //
       //     If the height is larger than the width
       //     _________________________   ʌ            ʌ
       //     |                       |   |         cropped
       //     |+++++++++++++++++++++++|   |      ʌ     v
       //     |                       |   |      |
       //     |                       |   |      |
       //     |           +           | height width
       //     |                       |   |      |
       //     |                       |   |      |
       //     |+++++++++++++++++++++++|   |      v     ʌ
       //     |                       |   |         cropped
       //     |_______________________|   v            v
       //
       //
       //
       //   - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the
       //     mobilenetv2 model.
       lateinit var bitmapCropped: Bitmap
       if (bitmap.getWidth() >= bitmap.getHeight()) {
          // Since height is smaller than the width, we crop a square whose length is the height
          // So cropping happens along the width dimesion
          val width: Int = bitmap.getHeight()
          val height: Int = bitmap.getHeight()
    
          // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)
          // so that the cropped width contains equal portion of the width on either side of center
          // top side of the cropped image must begin at 0 since we are not cropping along the height
          // dimension
          val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
          val y: Int = 0
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       } else {
          // Since width is smaller than the height, we crop a square whose length is the width
          // So cropping happens along the height dimesion
          val width: Int = bitmap.getWidth()
          val height: Int = bitmap.getWidth()
    
          // left side of the cropped image must begin at 0 since we are not cropping along the width
          // dimension
          // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)
          // so that the cropped height contains equal portion of the height on either side of center
          val x: Int = 0
          val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       }
    
       // Resize the image to be channels x width x height as needed by the mobilenetv2 model
       val width: Int = 224
       val height: Int = 224
       val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    
       return bitmapResized
    }
    

    c. Normalisieren des Bildes.

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
       // This function iterates over the image and performs the following
       // on the image pixels
       //   - normalizes the pixel values to be between 0 and 1
       //   - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration)
       //     from the pixel values
       //   - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the
       //     mobilenetv2 model configuration)
       // Values are written to the given buffer starting at the provided offset.
       // Values are written as follows
       // |____|____________________|__________________| <--- buffer
       //      ʌ                                         <--- offset
       //                           ʌ                    <--- offset + width * height * channels
       // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
       // |____|______|gggggg|______|__________________| <--- green channel read in column major order
       // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    
       val width: Int = bitmap.getWidth()
       val height: Int = bitmap.getHeight()
       val stride: Int = width * height
    
       for (x in 0 until width) {
          for (y in 0 until height) {
                val color: Int = bitmap.getPixel(y, x)
                val index = offset + (x * height + y)
    
                // Subtract the mean and divide by the standard deviation
                // Values for mean and standard deviation used for
                // the movilenetv2 model.
                buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
                buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
                buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
          }
       }
    }
    

    d. Abrufen eines Bitmaps von einer URI

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
       // This function reads the image file at the given uri and decodes it to a bitmap
       val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
       return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
    
  6. Anwendungs-Frontend

    a. Für dieses Tutorial verwenden wir die folgenden Benutzeroberflächenelemente:

    • Schaltflächen zum Trainieren und Inferieren
    • Klassenschaltflächen
    • Statusmeldungs-Text
    • Bildanzeige
    • Fortschrittsdialog

    b. Dieses Tutorial beabsichtigt nicht zu zeigen, wie die grafische Benutzeroberfläche erstellt wird. Aus diesem Grund werden wir einfach die auf GitHub verfügbaren Dateien wiederverwenden.

    c. Kopieren Sie alle Zeichenfolldefinitionen aus strings.xml in Ihre lokale strings.xml in Android Studio.

    d. Kopieren Sie den Inhalt von activity_main.xml in Ihre lokale activity_main.xml in Android Studio.

    e. Erstellen Sie eine neue Datei im Ordner layout namens dialog.xml. Kopieren Sie den Inhalt von dialog.xml in Ihre neu erstellte lokale dialog.xml in Android Studio.

    f. Die restlichen Änderungen in diesem Abschnitt müssen in der Datei MainActivity.kt vorgenommen werden.

    g. Starten der Anwendung

    Wenn die Anwendung gestartet wird, wird die Funktion onCreate aufgerufen. Diese Funktion ist für die Einrichtung des Sitzungscaches und der Benutzeroberflächenhandler verantwortlich.

    Bitte verweisen Sie auf die Funktion onCreate in der Datei MainActivity.kt für den Code.

    h. Handler für benutzerdefinierte Klassenschaltflächen – Wir möchten die Klassenschaltflächen verwenden, damit Benutzer ihre benutzerdefinierten Bilder zum Trainieren auswählen können. Wir müssen die Listener für diese Schaltflächen hinzufügen, um dies zu tun. Diese Listener erledigen genau das.

    Bitte verweisen Sie auf diese Schaltflächenhandler in MainActivity.kt

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. Personalisierung der benutzerdefinierten Klassenbeschriftungen

    Standardmäßig sind die benutzerdefinierten Klassenbeschriftungen [A, B, X, Y]. Aber lassen Sie uns Benutzern erlauben, diese Beschriftungen zur Klarheit umzubenennen. Dies wird durch Langklick-Listener erreicht, nämlich (definiert in MainActivity.kt)

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. Umschalten der benutzerdefinierten Klassen.

    Wenn der Schalter für benutzerdefinierte Klassen deaktiviert ist, wird der vorinstallierte Tiersdatensatz ausgeführt. Wenn er aktiviert ist, wird erwartet, dass der Benutzer seinen eigenen Datensatz zum Trainieren mitbringt. Um diesen Übergang zu handhaben, ist der Schalter-Handler onCustomClassSettingChangedListener in MainActivity.kt implementiert.

    k. Trainingshandler

    Wenn jede Klasse mindestens 1 Bild hat, kann die Schaltfläche Train aktiviert werden. Wenn die Schaltfläche Train angeklickt wird, beginnt das Training für die ausgewählten Bilder. Der Trainingshandler ist verantwortlich für:

    • Sammeln der Trainingsbilder in einem Container.
    • Mischen der Reihenfolge der Bilder.
    • Zuschneiden und Skalieren der Bilder.
    • Normalisieren der Bilder.
    • Batch-Verarbeitung der Bilder.
    • Ausführen der Trainingsschleife (Aufrufen der C++-Funktion performTraining in einer Schleife).

    Die Funktion onTrainButtonClickedListener in MainActivity.kt erledigt das oben Genannte.

    l. Inferenzhandler

    Nach Abschluss des Trainings kann der Benutzer auf die Schaltfläche Infer klicken, um eine Inferenz durchzuführen. Der Inferenzhandler ist verantwortlich für:

    • Sammeln des Inferenzbildes.
    • Zuschneiden und Skalieren des Bildes.
    • Normalisieren des Bildes.
    • Aufrufen der C++-Funktion performInference.
    • Melden der inferierten Ausgabe an die Benutzeroberfläche.

    Dies wird durch die Funktion onInferenceButtonClickedListener in MainActivity.kt erreicht.

    m. Handler für alle oben genannten Aktivitäten

    Sobald die Bilder für die Inferenz oder für die benutzerdefinierten Klassen ausgewählt wurden, müssen sie verarbeitet werden. Die Funktion onActivityResult in MainActivity.kt erledigt dies.

    n. Ein letztes Ding. Fügen Sie das Folgende in die Datei AndroidManifest.xml ein, um die Kamera zu verwenden:

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    

Trainingsphase – Ausführen der Anwendung auf einem Gerät

  1. Ausführen der Anwendung auf einem Gerät

    a. Schließen wir unser Android-Gerät an den Computer an und führen die Anwendung auf dem Gerät aus.

    b. Das Starten der Anwendung auf dem Gerät sollte wie folgt aussehen:

    Barebones ORT Personalize app

  2. Training mit einem vorab geladenen Datensatz – Tiere

    a. Beginnen wir mit dem Training unter Verwendung des vorab geladenen Tierdatensatzes, indem wir die Anwendung auf dem Gerät starten.

    b. Schalten Sie den Schalter Benutzerdefinierte Klassen unten um.

    c. Die Klassenbeschriftungen ändern sich zu Hund, Katze, Elefant und Kuh.

    d. Führen Sie Training aus und warten Sie, bis der Fortschrittsdialog verschwindet (nach Abschluss des Trainings).

    e. Verwenden Sie nun ein beliebiges Tierbild aus Ihrer Bibliothek zur Inferenz.

    ORT Personalize app with an image of a cow

    Wie aus dem obigen Bild ersichtlich ist, hat das Modell Kuh korrekt vorhergesagt.

  3. Training mit einem benutzerdefinierten Datensatz – Prominente

    a. Laden Sie Bilder von Tom Cruise, Leonardo DiCaprio, Ryan Reynolds und Brad Pitt aus dem Internet herunter.

    b. Stellen Sie sicher, dass Sie eine neue Sitzung der App starten, indem Sie die App schließen und neu starten.

    c. Benennen Sie nach dem Start der Anwendung die vier Klassen durch Langklicken in Tom, Leo, Ryan und Brad um.

    d. Klicken Sie auf die Schaltfläche für jede Klasse und wählen Sie Bilder aus, die mit diesem Prominenten verbunden sind. Wir können etwa 10–15 Bilder pro Kategorie verwenden.

    e. Drücken Sie die Schaltfläche Train und lassen Sie die Anwendung aus den bereitgestellten Daten lernen.

    f. Sobald das Training abgeschlossen ist, können wir auf die Schaltfläche Infer klicken und ein Bild bereitstellen, das die Anwendung noch nicht gesehen hat.

    g. Das war's! Hoffentlich hat die Anwendung das Bild korrekt klassifiziert.

    an image classification app with Tom Cruise in the middle.

Fazit

Herzlichen Glückwunsch! Sie haben erfolgreich eine Android-Anwendung erstellt, die Bilder mithilfe von ONNX Runtime auf dem Gerät klassifizieren lernt. Die Anwendung ist auch auf GitHub unter onnxruntime-training-examples verfügbar.