On-Device Training: Aufbau einer Android-Anwendung
In diesem Tutorial erfahren Sie, wie Sie eine Android-Anwendung erstellen, die die On-Device-Training-Lösung von ONNX Runtime integriert. On-Device Training bezieht sich auf den Prozess des Trainings eines Machine-Learning-Modells direkt auf einem Edge-Gerät, ohne auf Cloud-Dienste oder externe Server zurückzugreifen.
So wird die Anwendung am Ende dieses Tutorials aussehen

Einleitung
Wir führen Sie durch die Schritte zur Erstellung einer Android-App, die ein einfaches Bildklassifizierungsmodell mithilfe von On-Device-Trainingstechniken trainieren kann. Dieses Tutorial zeigt die Transfer-Learning-Technik, bei der Wissen, das aus dem Training eines Modells für eine Aufgabe gewonnen wurde, genutzt wird, um die Leistung eines Modells für eine andere, aber verwandte Aufgabe zu verbessern. Anstatt den Lernprozess von Grund auf neu zu beginnen, ermöglicht Transfer Learning, dass Wissen oder Merkmale, die von einem vortrainierten Modell gelernt wurden, auf eine neue Aufgabe übertragen werden.
Für dieses Tutorial verwenden wir das Modell MobileNetV2, das auf groß angelegten Bilddatensätzen wie ImageNet (das 1.000 Klassen hat) trainiert wurde. Wir werden dieses Modell verwenden, um benutzerdefinierte Daten in eine von vier Klassen zu klassifizieren. Die anfänglichen Schichten von MobileNetV2 dienen als Merkmalsextraktor und erfassen allgemeine visuelle Merkmale, die für verschiedene Aufgaben anwendbar sind. Nur die letzte Klassifizierungsschicht wird für die jeweilige Aufgabe trainiert.
In diesem Tutorial lernen wir mit Daten zu
- Klassifizieren Sie Tiere in eine von vier Kategorien mithilfe eines vorinstallierten Tierdatensatzes.
- Klassifizieren Sie Prominente in eine von vier Kategorien mithilfe eines benutzerdefinierten Prominentendatensatzes.
Inhaltsverzeichnis
- Inhaltsverzeichnis
- Voraussetzungen
- Offline-Phase – Erstellen der Trainingsartefakte
- Trainingsphase – Android-Anwendungsentwicklung
- Trainingsphase – Ausführen der Anwendung auf einem Gerät
- Schlussfolgerung
Voraussetzungen
Um dieses Tutorial befolgen zu können, sollten Sie über grundlegende Kenntnisse in der Android-App-Entwicklung mit Java oder Kotlin verfügen. Kenntnisse in C++ sowie Kenntnisse von Machine-Learning-Konzepten wie neuronalen Netzen und Bildklassifizierung sind ebenfalls hilfreich.
- Python-Entwicklungsumgebung zur Vorbereitung der Trainingsartefakte
- Android Studio 4.1+
- Android SDK 29+
- Android NDK r21+
- Ein Android-Gerät mit Kamera im Entwicklermodus mit aktiviertem USB-Debugging
Hinweis Die gesamte Android-Anwendung ist auch im GitHub-Repository
onnxruntime-training-examplesverfügbar.
Offline-Phase – Erstellen der Trainingsartefakte
-
Exportieren Sie das Modell nach ONNX.
Wir beginnen mit einem vortrainierten PyTorch-Modell und exportieren es nach ONNX. Das Modell
MobileNetV2wurde auf dem Imagenet-Datensatz vortrainiert, der Daten in 1.000 Kategorien enthält. Für unsere Aufgabe der Bildklassifizierung möchten wir nur Bilder in 4 Klassen klassifizieren. Daher ändern wir die letzte Schicht des Modells so, dass sie 4 Logits anstelle von 1.000 ausgibt.Weitere Details zum Exportieren von PyTorch-Modellen nach ONNX finden Sie hier.
import torch import torchvision model = torchvision.models.mobilenet_v2( weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2) # The original model is trained on imagenet which has 1000 classes. # For our image classification scenario, we need to classify among 4 categories. # So we need to change the last layer of the model to have 4 outputs. model.classifier[1] = torch.nn.Linear(1280, 4) # Export the model to ONNX. model_name = "mobilenetv2" torch.onnx.export(model, torch.randn(1, 3, 224, 224), f"training_artifacts/{model_name}.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}) -
Definieren Sie die trainierbaren und nicht trainierbaren Parameter
import onnx # Load the onnx model. onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx") # Define the parameters that require their gradients to be computed # (trainable parameters) and those that do not (frozen/non trainable parameters). requires_grad = ["classifier.1.weight", "classifier.1.bias"] frozen_params = [ param.name for param in onnx_model.graph.initializer if param.name not in requires_grad ] -
Generieren Sie die Trainingsartefakte.
Wir verwenden für dieses Tutorial den
CrossEntropyLoss-Verlust und denAdamW-Optimierer. Weitere Details zur Artefakterstellung finden Sie hier.from onnxruntime.training import artifacts # Generate the training artifacts. artifacts.generate_artifacts( onnx_model, requires_grad=requires_grad, frozen_params=frozen_params, loss=artifacts.LossType.CrossEntropyLoss, optimizer=artifacts.OptimType.AdamW, artifact_directory="training_artifacts" )Das ist alles! Die Trainingsartefakte wurden im Ordner
training_artifactsgeneriert. Dies markiert das Ende der Offline-Phase. Diese Artefakte sind nun bereit für die Bereitstellung auf dem Android-Gerät für das Training.
Trainingsphase – Android-Anwendungsentwicklung
-
Einrichten des Projekts in Android Studio
a. Öffnen Sie Android Studio und klicken Sie auf
Neues Projekt
b. Klicken Sie auf
Native C++->Weiter. Füllen Sie die Details desNeues Projektwie folgt aus:- Name -
ORT Personalize - Paketname -
com.example.ortpersonalize - Sprache -
Kotlin
Klicken Sie auf
Weiter.
c. Wählen Sie die
C++17-Toolchain ->Fertigstellen
d. Das war's! Das Android Studio-Projekt ist eingerichtet. Sie sollten nun den Android Studio-Editor mit etwas Boilerplate-Code sehen.
- Name -
-
Hinzufügen der ONNX Runtime-Abhängigkeit
a. Erstellen Sie zwei neue Ordner namens
libundinclude\onnxruntimeunter dem cpp-Verzeichnis im Android Studio-Projekt.
b. Gehen Sie zu Maven Central. Gehen Sie zu
Versionen->Durchsuchen-> und laden Sie das Archivpaket (aar-Datei)onnxruntime-training-androidherunter.c. Benennen Sie die Erweiterung
aarinzipum. Also wirdonnxruntime-training-android-1.15.0.aarzuonnxruntime-training-android-1.15.0.zip.d. Entpacken Sie den Inhalt der Zip-Datei.
e. Kopieren Sie die gemeinsam genutzte Bibliothek
libonnxruntime.soaus dem Ordnerjni\arm64-v8ain Ihr Android-Projekt unter den neu erstellten Ordnerlib.f. Kopieren Sie den Inhalt des Ordners
headersin den neu erstellten Ordnerinclude\onnxruntime.g. Fügen Sie in der Datei
native-lib.cppdie Trainings-C++-Header-Datei ein.#include "onnxruntime_training_cxx_api.h"h. Fügen Sie
abiFilterszur Dateibuild.gradle (Modul)hinzu, umarm64-v8aauszuwählen. Diese Einstellung muss unterdefaultConfiginbuild.gradlehinzugefügt werden.ndk { abiFilters 'arm64-v8a' }Beachten Sie, dass der Abschnitt
defaultConfigderbuild.gradle-Datei wie folgt aussehen sollte:defaultConfig { applicationId "com.example.ortpersonalize" minSdk 29 targetSdk 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" externalNativeBuild { cmake { cppFlags '-std=c++17' } } + ndk { + abiFilters 'arm64-v8a' + } }i. Fügen Sie die gemeinsam genutzte Bibliothek
onnxruntimezur DateiCMakeLists.txthinzu, damitcmakedie gemeinsam genutzte Bibliothek finden und dagegen linken kann. Fügen Sie dazu diese Zeilen hinzu, nachdem die BibliothekortpersonalizeinCMakeLists.txthinzugefügt wurde:add_library(onnxruntime SHARED IMPORTED) set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)Lassen Sie
CMakewissen, wo die ONNX Runtime-Header-Dateien zu finden sind, indem Sie diese Zeile direkt nach den obigen beiden Zeilen hinzufügen:target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)Linken Sie das Android C++-Projekt gegen die Bibliothek
onnxruntime, indem Sie die Bibliothekonnxruntimezutarget_link_librarieshinzufügen.target_link_libraries( # Specifies the target library. ortpersonalize # Links the target library to the log library # included in the NDK. ${log-lib} onnxruntime)Beachten Sie, dass die Datei
CMakeLists.txtwie folgt aussehen sollte:project("ortpersonalize") add_library( # Sets the name of the library. ortpersonalize # Sets the library as a shared library. SHARED # Provides a relative path to your source file(s). native-lib.cpp + utils.cpp + inference.cpp + train.cpp) + add_library(onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so) + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime) find_library( # Sets the name of the path variable. log-lib # Specifies the name of the NDK library that # you want CMake to locate. log) target_link_libraries( # Specifies the target library. ortpersonalize # Links the target library to the log library # included in the NDK. ${log-lib} + onnxruntime)j. Erstellen Sie die Anwendung und warten Sie auf Erfolg, um zu bestätigen, dass die App die ONNX Runtime-Header integriert hat und erfolgreich gegen die gemeinsam genutzte ONNX Runtime-Bibliothek linken kann.
-
Verpacken der vorab erstellten Trainingsartefakte und des Datensatzes
a. Erstellen Sie einen neuen Ordner
assetsim Ordnerappim linken Bereich des Android Studio-Projekts, indem Sie mit der rechten Maustaste auf app klicken -> Neu -> Ordner -> Assets-Ordner und ihn unter main platzieren.b. Kopieren Sie die in Schritt 2 generierten Trainingsartefakte in diesen Ordner.
c. Gehen Sie nun zum
onnxruntime-training-examplesRepo und laden Sie den Datensatz (images.zip) auf Ihren Computer herunter und entpacken Sie ihn. Dieser Datensatz wurde aus dem ursprünglichenanimals-10-Datensatz auf Kaggle modifiziert, der von Corrado Alessio erstellt wurde.d. Kopieren Sie den heruntergeladenen Ordner
imagesin das Verzeichnisassets/imagesin Android Studio.Der linke Bereich des Projekts sollte wie folgt aussehen:

-
Schnittstelle mit ONNX Runtime – C++-Code
a. Wir werden die folgenden vier Funktionen in C++ implementieren, die von der Anwendung aufgerufen werden:
createSession: Wird beim Start der Anwendung aufgerufen. Es werden neue Objekte vom TypCheckpointStateundTrainingSessionerstellt.releaseSession: Wird aufgerufen, wenn die Anwendung geschlossen wird. Diese Funktion gibt die zu Beginn der Anwendung zugewiesenen Ressourcen frei.performTraining: Wird aufgerufen, wenn der Benutzer auf die SchaltflächeTrainauf der Benutzeroberfläche klickt.performInference: Wird aufgerufen, wenn der Benutzer auf die SchaltflächeInferauf der Benutzeroberfläche klickt.
b. Sitzung erstellen
Diese Funktion wird beim Start der Anwendung aufgerufen. Sie verwendet die Trainingsartefakte im Assets-Ordner, um die
C++-Objekte CheckpointState und TrainingSession zu erstellen. Diese Objekte werden für das Training des Modells auf dem Gerät verwendet.Die Argumente für
createSessionsind:checkpoint_path: Gespeichertes Pfad zum Checkpoint-Artefakt.train_model_path: Gespeichertes Pfad zum Trainingsmodell-Artefakt.eval_model_path: Gespeichertes Pfad zum Evaluierungsmodell-Artefakt.optimizer_model_path: Gespeichertes Pfad zum Optimierermodell-Artefakt.cache_dir_path: Pfad zum Cache-Verzeichnis auf dem Android-Gerät. Das Cache-Verzeichnis wird als Zugriffsmethode auf die Trainingsartefakte aus dem C++-Code verwendet.
Die Funktion gibt eine
longzurück, die einen Zeiger auf das Objektsession_cachedarstellt. Dieserlongkann inSessionCacheumgewandelt werden, wann immer wir Zugriff auf die Trainingssitzung benötigen.extern "C" JNIEXPORT jlong JNICALL Java_com_example_ortpersonalize_MainActivity_createSession( JNIEnv *env, jobject /* this */, jstring checkpoint_path, jstring train_model_path, jstring eval_model_path, jstring optimizer_model_path, jstring cache_dir_path) { std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>( utils::JString2String(env, checkpoint_path), utils::JString2String(env, train_model_path), utils::JString2String(env, eval_model_path), utils::JString2String(env, optimizer_model_path), utils::JString2String(env, cache_dir_path)); return reinterpret_cast<long>(session_cache.release()); }Wie aus dem obigen Funktionskörper ersichtlich ist, erstellt diese Funktion einen eindeutigen Zeiger auf ein Objekt der Klasse
SessionCache. Die Definition vonSessionCacheist unten angegeben.struct SessionCache { ArtifactPaths artifact_paths; Ort::Env ort_env; Ort::SessionOptions session_options; Ort::CheckpointState checkpoint_state; Ort::TrainingSession training_session; Ort::Session* inference_session; SessionCache(const std::string &checkpoint_path, const std::string &training_model_path, const std::string &eval_model_path, const std::string &optimizer_model_path, const std::string& cache_dir_path) : artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path), ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(), checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())), training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(), artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()), inference_session(nullptr) {} };Die Definition von
ArtifactPathslautet:struct ArtifactPaths { std::string checkpoint_path; std::string training_model_path; std::string eval_model_path; std::string optimizer_model_path; std::string cache_dir_path; std::string inference_model_path; ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path, const std::string &eval_model_path, const std::string &optimizer_model_path, const std::string& cache_dir_path) : checkpoint_path(checkpoint_path), training_model_path(training_model_path), eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path), cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {} };c. Sitzung freigeben
Diese Funktion wird aufgerufen, wenn die Anwendung heruntergefahren wird. Sie gibt die Ressourcen frei, die beim Start der Anwendung erstellt wurden, hauptsächlich CheckpointState und TrainingSession.
Die Argumente für
releaseSessionsind:session:long, das das ObjektSessionCachedarstellt.
extern "C" JNIEXPORT void JNICALL Java_com_example_ortpersonalize_MainActivity_releaseSession( JNIEnv *env, jobject /* this */, jlong session) { auto *session_cache = reinterpret_cast<SessionCache *>(session); delete session_cache->inference_session; delete session_cache; }d. Training durchführen
Diese Funktion wird für jeden Batch aufgerufen, der trainiert werden muss. Die Trainingsschleife ist auf der Anwendungsseite in Kotlin geschrieben, und innerhalb der Trainingsschleife wird die Funktion
performTrainingfür jeden Batch aufgerufen.Die Argumente für
performTrainingsind:session:long, das das ObjektSessionCachedarstellt.batch: Eingabebilder als Float-Array, das für das Training übergeben wird.labels: Labels als Integer-Array, die den für das Training bereitgestellten Eingabebildern zugeordnet sind.batch_size: Anzahl der Bilder, die mit jedemTrainStepverarbeitet werden.channels: Anzahl der Farbkanäle im Bild. Für unser Beispiel wird dieser Wert immer mit3aufgerufen.frame_rows: Anzahl der Zeilen im Bild. Für unser Beispiel wird dieser Wert immer mit224aufgerufen.frame_cols: Anzahl der Spalten im Bild. Für unser Beispiel wird dieser Wert immer mit224aufgerufen.
Die Funktion gibt einen
floatzurück, der den Trainingsverlust für diesen Batch darstellt.extern "C" JNIEXPORT float JNICALL Java_com_example_ortpersonalize_MainActivity_performTraining( JNIEnv *env, jobject /* this */, jlong session, jfloatArray batch, jintArray labels, jint batch_size, jint channels, jint frame_rows, jint frame_cols) { auto* session_cache = reinterpret_cast<SessionCache *>(session); if (session_cache->inference_session) { // Invalidate the inference session since we will be updating the model parameters // in train_step. // The next call to inference session will need to recreate the inference session. delete session_cache->inference_session; session_cache->inference_session = nullptr; } // Update the model parameters using this batch of inputs. return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr), env->GetIntArrayElements(labels, nullptr), batch_size, channels, frame_rows, frame_cols); }Die obige Funktion nutzt die Funktion
train_step. Die Definition der Funktiontrain_steplautet wie folgt:namespace training { float train_step(SessionCache* session_cache, float *batches, int32_t *labels, int64_t batch_size, int64_t image_channels, int64_t image_rows, int64_t image_cols) { const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols}); const std::vector<int64_t> labels_shape({batch_size}); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector<Ort::Value> user_inputs; // {inputs, labels} // Inputs batched user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches, batch_size * image_channels * image_rows * image_cols * sizeof(float), input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); // Labels batched user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels, batch_size * sizeof(int32_t), labels_shape.data(), labels_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); // Run the train step and execute the forward + loss + backward. float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>()); // Update the model parameters by taking a step in the direction of the gradients computed above. session_cache->training_session.OptimizerStep(); // Reset the gradients now that the parameters have been updated. // New set of gradients can then be computed for the next round of inputs. session_cache->training_session.LazyResetGrad(); return loss; } } // namespace traininge. Inferenz durchführen
Diese Funktion wird aufgerufen, wenn der Benutzer eine Inferenz durchführen möchte.
Die Argumente für
performInferencesind:session:long, das das ObjektSessionCachedarstellt.image_buffer: Eingabebilder als Float-Array, das für das Training übergeben wird.batch_size: Anzahl der Bilder, die bei jeder Inferenz verarbeitet werden. Für unser Beispiel wird dieser Wert immer mit1aufgerufen.image_channels: Anzahl der Farbkanäle im Bild. Für unser Beispiel wird dieser Wert immer mit3aufgerufen.image_rows: Anzahl der Zeilen im Bild. Für unser Beispiel wird dieser Wert immer mit224aufgerufen.image_cols: Anzahl der Spalten im Bild. Für unser Beispiel wird dieser Wert immer mit224aufgerufen.classes: Liste von Zeichenketten, die alle vier benutzerdefinierten Klassen darstellen.
Die Funktion gibt eine
stringzurück, die eine der vier bereitgestellten benutzerdefinierten Klassen darstellt. Dies ist die Vorhersage des Modells.extern "C" JNIEXPORT jstring JNICALL Java_com_example_ortpersonalize_MainActivity_performInference( JNIEnv *env, jobject /* this */, jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows, jint image_cols, jobjectArray classes) { std::vector<std::string> classes_str; for (int i = 0; i < env->GetArrayLength(classes); ++i) { // Access the current string element jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i)); classes_str.push_back(utils::JString2String(env, elem)); } auto* session_cache = reinterpret_cast<SessionCache *>(session); if (!session_cache->inference_session) { // The inference session does not exist, so create a new one. session_cache->training_session.ExportModelForInferencing( session_cache->artifact_paths.inference_model_path.c_str(), {"output"}); session_cache->inference_session = std::make_unique<Ort::Session>( session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(), session_cache->session_options).release(); } auto prediction = inference::classify( session_cache, env->GetFloatArrayElements(image_buffer, nullptr), batch_size, image_channels, image_rows, image_cols, classes_str); return env->NewStringUTF(prediction.first.c_str()); }Die obige Funktion ruft
classifyauf. Die Definition von classify lautet:namespace inference { std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data, int64_t batch_size, int64_t image_channels, int64_t image_rows, int64_t image_cols, const std::vector<std::string>& classes) { std::vector<const char *> input_names = {"input"}; size_t input_count = 1; std::vector<const char *> output_names = {"output"}; size_t output_count = 1; std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols}); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector<Ort::Value> input_values; // {input images} input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data, batch_size * image_channels * image_rows * image_cols * sizeof(float), input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); std::vector<Ort::Value> output_values; output_values.emplace_back(nullptr); // get the logits session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(), input_count, output_names.data(), output_values.data(), output_count); float *output = output_values.front().GetTensorMutableData<float>(); // run softmax and get the probabilities of each class std::vector<float> probabilities = Softmax(output, classes.size()); size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end())); return {classes[best_index], probabilities[best_index]}; } } // namespace inferenceDie Funktion classify ruft eine weitere Funktion namens
Softmaxauf. Die Definition vonSoftmaxlautet:std::vector<float> Softmax(float *logits, size_t num_logits) { std::vector<float> probabilities(num_logits, 0); float sum = 0; for (size_t i = 0; i < num_logits; ++i) { probabilities[i] = exp(logits[i]); sum += probabilities[i]; } if (sum != 0.0f) { for (size_t i = 0; i < num_logits; ++i) { probabilities[i] /= sum; } } return probabilities; } -
a. Das Modell
MobileNetV2erwartet, dass das bereitgestellte Eingabebild:- Größe
3 x 224 x 224hat. - ein normalisiertes Bild, bei dem der Mittelwert
(0.485, 0.456, 0.406)subtrahiert und durch die Standardabweichung(0.229, 0.224, 0.225)geteilt wurde.
Diese Vorverarbeitung erfolgt in Java/Kotlin unter Verwendung der von Android bereitgestellten Bibliotheken.
Erstellen wir eine neue Datei namens
ImageProcessingUtil.ktim Verzeichnisapp/src/main/java/com/example/ortpersonalize. In dieser Datei fügen wir die Hilfsmethoden zum Zuschneiden und Skalieren sowie zum Normalisieren der Bilder hinzu.b. Zuschneiden und Skalieren des Bildes.
fun processBitmap(bitmap: Bitmap) : Bitmap { // This function processes the given bitmap by // - cropping along the longer dimension to get a square bitmap // If the width is larger than the height // ___+_________________+___ // | + + | // | + + | // | + + + | // | + + | // |__+_________________+__| // <-------- width --------> // <----- height ----> // <--> cropped <--> // // If the height is larger than the width // _________________________ ʌ ʌ // | | | cropped // |+++++++++++++++++++++++| | ʌ v // | | | | // | | | | // | + | height width // | | | | // | | | | // |+++++++++++++++++++++++| | v ʌ // | | | cropped // |_______________________| v v // // // // - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the // mobilenetv2 model. lateinit var bitmapCropped: Bitmap if (bitmap.getWidth() >= bitmap.getHeight()) { // Since height is smaller than the width, we crop a square whose length is the height // So cropping happens along the width dimesion val width: Int = bitmap.getHeight() val height: Int = bitmap.getHeight() // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2) // so that the cropped width contains equal portion of the width on either side of center // top side of the cropped image must begin at 0 since we are not cropping along the height // dimension val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2 val y: Int = 0 bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height) } else { // Since width is smaller than the height, we crop a square whose length is the width // So cropping happens along the height dimesion val width: Int = bitmap.getWidth() val height: Int = bitmap.getWidth() // left side of the cropped image must begin at 0 since we are not cropping along the width // dimension // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2) // so that the cropped height contains equal portion of the height on either side of center val x: Int = 0 val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2 bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height) } // Resize the image to be channels x width x height as needed by the mobilenetv2 model val width: Int = 224 val height: Int = 224 val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false) return bitmapResized }c. Normalisieren des Bildes.
fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) { // This function iterates over the image and performs the following // on the image pixels // - normalizes the pixel values to be between 0 and 1 // - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration) // from the pixel values // - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the // mobilenetv2 model configuration) // Values are written to the given buffer starting at the provided offset. // Values are written as follows // |____|____________________|__________________| <--- buffer // ʌ <--- offset // ʌ <--- offset + width * height * channels // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order // |____|______|gggggg|______|__________________| <--- green channel read in column major order // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order val width: Int = bitmap.getWidth() val height: Int = bitmap.getHeight() val stride: Int = width * height for (x in 0 until width) { for (y in 0 until height) { val color: Int = bitmap.getPixel(y, x) val index = offset + (x * height + y) // Subtract the mean and divide by the standard deviation // Values for mean and standard deviation used for // the movilenetv2 model. buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f) buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f) buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f) } } }d. Abrufen eines Bitmaps von einer URI
fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap { // This function reads the image file at the given uri and decodes it to a bitmap val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri) return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true) } - Größe
-
a. Für dieses Tutorial verwenden wir die folgenden Benutzeroberflächenelemente:
- Schaltflächen zum Trainieren und Inferieren
- Klassenschaltflächen
- Statusmeldungs-Text
- Bildanzeige
- Fortschrittsdialog
b. Dieses Tutorial beabsichtigt nicht zu zeigen, wie die grafische Benutzeroberfläche erstellt wird. Aus diesem Grund werden wir einfach die auf GitHub verfügbaren Dateien wiederverwenden.
c. Kopieren Sie alle Zeichenfolldefinitionen aus
strings.xmlin Ihre lokalestrings.xmlin Android Studio.d. Kopieren Sie den Inhalt von
activity_main.xmlin Ihre lokaleactivity_main.xmlin Android Studio.e. Erstellen Sie eine neue Datei im Ordner
layoutnamensdialog.xml. Kopieren Sie den Inhalt vondialog.xmlin Ihre neu erstellte lokaledialog.xmlin Android Studio.f. Die restlichen Änderungen in diesem Abschnitt müssen in der Datei MainActivity.kt vorgenommen werden.
g. Starten der Anwendung
Wenn die Anwendung gestartet wird, wird die Funktion
onCreateaufgerufen. Diese Funktion ist für die Einrichtung des Sitzungscaches und der Benutzeroberflächenhandler verantwortlich.Bitte verweisen Sie auf die Funktion
onCreatein der DateiMainActivity.ktfür den Code.h. Handler für benutzerdefinierte Klassenschaltflächen – Wir möchten die Klassenschaltflächen verwenden, damit Benutzer ihre benutzerdefinierten Bilder zum Trainieren auswählen können. Wir müssen die Listener für diese Schaltflächen hinzufügen, um dies zu tun. Diese Listener erledigen genau das.
Bitte verweisen Sie auf diese Schaltflächenhandler in
MainActivity.kt- onClassAClickedListener
- onClassBClickedListener
- onClassXClickedListener
- onClassYClickedListener
i. Personalisierung der benutzerdefinierten Klassenbeschriftungen
Standardmäßig sind die benutzerdefinierten Klassenbeschriftungen
[A, B, X, Y]. Aber lassen Sie uns Benutzern erlauben, diese Beschriftungen zur Klarheit umzubenennen. Dies wird durch Langklick-Listener erreicht, nämlich (definiert inMainActivity.kt)- onClassALongClickedListener
- onClassBLongClickedListener
- onClassXLongClickedListener
- onClassYLongClickedListener
j. Umschalten der benutzerdefinierten Klassen.
Wenn der Schalter für benutzerdefinierte Klassen deaktiviert ist, wird der vorinstallierte Tiersdatensatz ausgeführt. Wenn er aktiviert ist, wird erwartet, dass der Benutzer seinen eigenen Datensatz zum Trainieren mitbringt. Um diesen Übergang zu handhaben, ist der Schalter-Handler
onCustomClassSettingChangedListenerinMainActivity.ktimplementiert.k. Trainingshandler
Wenn jede Klasse mindestens 1 Bild hat, kann die Schaltfläche
Trainaktiviert werden. Wenn die SchaltflächeTrainangeklickt wird, beginnt das Training für die ausgewählten Bilder. Der Trainingshandler ist verantwortlich für:- Sammeln der Trainingsbilder in einem Container.
- Mischen der Reihenfolge der Bilder.
- Zuschneiden und Skalieren der Bilder.
- Normalisieren der Bilder.
- Batch-Verarbeitung der Bilder.
- Ausführen der Trainingsschleife (Aufrufen der C++-Funktion
performTrainingin einer Schleife).
Die Funktion
onTrainButtonClickedListenerinMainActivity.kterledigt das oben Genannte.l. Inferenzhandler
Nach Abschluss des Trainings kann der Benutzer auf die Schaltfläche
Inferklicken, um eine Inferenz durchzuführen. Der Inferenzhandler ist verantwortlich für:- Sammeln des Inferenzbildes.
- Zuschneiden und Skalieren des Bildes.
- Normalisieren des Bildes.
- Aufrufen der C++-Funktion
performInference. - Melden der inferierten Ausgabe an die Benutzeroberfläche.
Dies wird durch die Funktion
onInferenceButtonClickedListenerinMainActivity.kterreicht.m. Handler für alle oben genannten Aktivitäten
Sobald die Bilder für die Inferenz oder für die benutzerdefinierten Klassen ausgewählt wurden, müssen sie verarbeitet werden. Die Funktion
onActivityResultinMainActivity.kterledigt dies.n. Ein letztes Ding. Fügen Sie das Folgende in die Datei
AndroidManifest.xmlein, um die Kamera zu verwenden:<uses-permission android:name="android.permission.CAMERA" /> <uses-feature android:name="android.hardware.camera" />
Trainingsphase – Ausführen der Anwendung auf einem Gerät
-
Ausführen der Anwendung auf einem Gerät
a. Schließen wir unser Android-Gerät an den Computer an und führen die Anwendung auf dem Gerät aus.
b. Das Starten der Anwendung auf dem Gerät sollte wie folgt aussehen:

-
Training mit einem vorab geladenen Datensatz – Tiere
a. Beginnen wir mit dem Training unter Verwendung des vorab geladenen Tierdatensatzes, indem wir die Anwendung auf dem Gerät starten.
b. Schalten Sie den Schalter
Benutzerdefinierte Klassenunten um.c. Die Klassenbeschriftungen ändern sich zu
Hund,Katze,ElefantundKuh.d. Führen Sie
Trainingaus und warten Sie, bis der Fortschrittsdialog verschwindet (nach Abschluss des Trainings).e. Verwenden Sie nun ein beliebiges Tierbild aus Ihrer Bibliothek zur Inferenz.

Wie aus dem obigen Bild ersichtlich ist, hat das Modell
Kuhkorrekt vorhergesagt. -
Training mit einem benutzerdefinierten Datensatz – Prominente
a. Laden Sie Bilder von Tom Cruise, Leonardo DiCaprio, Ryan Reynolds und Brad Pitt aus dem Internet herunter.
b. Stellen Sie sicher, dass Sie eine neue Sitzung der App starten, indem Sie die App schließen und neu starten.
c. Benennen Sie nach dem Start der Anwendung die vier Klassen durch Langklicken in
Tom,Leo,RyanundBradum.d. Klicken Sie auf die Schaltfläche für jede Klasse und wählen Sie Bilder aus, die mit diesem Prominenten verbunden sind. Wir können etwa 10–15 Bilder pro Kategorie verwenden.
e. Drücken Sie die Schaltfläche
Trainund lassen Sie die Anwendung aus den bereitgestellten Daten lernen.f. Sobald das Training abgeschlossen ist, können wir auf die Schaltfläche
Inferklicken und ein Bild bereitstellen, das die Anwendung noch nicht gesehen hat.g. Das war's! Hoffentlich hat die Anwendung das Bild korrekt klassifiziert.

Fazit
Herzlichen Glückwunsch! Sie haben erfolgreich eine Android-Anwendung erstellt, die Bilder mithilfe von ONNX Runtime auf dem Gerät klassifizieren lernt. Die Anwendung ist auch auf GitHub unter onnxruntime-training-examples verfügbar.