Bildauflösung mit Machine-Learning-Super-Resolution auf Mobilgeräten verbessern

Erfahren Sie, wie Sie eine Anwendung zur Verbesserung der Bildauflösung mit ONNX Runtime Mobile erstellen, mit einem Modell, das Vor- und Nachbearbeitung enthält.

Sie können dieses Tutorial verwenden, um die Anwendung für Android oder iOS zu erstellen.

Die Anwendung nimmt ein Bild als Eingabe, führt die Super-Resolution-Operation aus, wenn die Schaltfläche geklickt wird, und zeigt das Bild mit verbesserter Auflösung darunter an, wie im folgenden Screenshot dargestellt.

Super resolution on a cat

Inhalt

Modell vorbereiten

Das in diesem Tutorial verwendete Machine-Learning-Modell basiert auf dem im PyTorch-Tutorial am Ende dieser Seite genannten Modell.

Wir stellen ein praktisches Python-Skript zur Verfügung, das das PyTorch-Modell in das ONNX-Format exportiert und Vor- und Nachbearbeitung hinzufügt.

  1. Installieren Sie vor dem Ausführen dieses Skripts die folgenden Python-Pakete

     pip install torch
     pip install pillow
     pip install onnx
     pip install onnxruntime
     pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions
    

    Hinweis zu Versionen: Die besten Super-Resolution-Ergebnisse werden mit ONNX Opset 18 (mit Unterstützung für den Resize-Operator mit Antialiasing) erzielt, der von ONNX 1.13.0 und ONNX Runtime 1.14.0 und neuer unterstützt wird. Das Paket onnxruntime-extensions ist eine Vorabversion. Die Release-Version wird bald verfügbar sein.

  2. Laden Sie dann das Skript und das Testbild aus dem onnxruntime-extensions GitHub-Repository herunter (falls Sie dieses Repository noch nicht geklont haben).

     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py
     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
    
  3. Führen Sie das Skript aus, um das Kernmodell zu exportieren und Vor- und Nachbearbeitung hinzuzufügen.

     python superresolution_e2e.py 
    

Nachdem das Skript ausgeführt wurde, sollten Sie zwei ONNX-Dateien im Verzeichnis sehen, in dem Sie das Skript ausgeführt haben.

pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx

Wenn Sie die beiden Modelle in Netron laden, können Sie den Unterschied bei Eingaben und Ausgaben zwischen beiden sehen. Die ersten beiden Bilder unten zeigen das Originalmodell mit seinen Eingaben als Stapel von Kanal-Daten, und die zweiten beiden zeigen die Eingaben und Ausgaben als Bild-Bytes.

ONNX model without pre and post processing

ONNX model inputs and outputs without pre and post processing

ONNX model with pre and post processing

ONNX model inputs and outputs with pre and post processing

Jetzt ist es an der Zeit, den Anwendungscode zu schreiben.

Android-App

Voraussetzungen

  • Android Studio Dolphin 2021.3.1 Patch + (installiert unter Mac/Windows/Linux)
  • Android SDK 29+
  • Android NDK r22+
  • Ein Android-Gerät oder ein Android-Emulator

Beispielcode

Den vollständigen Quellcode für die Android-Super-Resolution-App finden Sie auf GitHub.

Um die App aus dem Quellcode auszuführen, klonen Sie das obige Repository und laden Sie die Datei build.gradle in Android Studio, bauen und führen Sie sie aus!

Um die App Schritt für Schritt zu erstellen, folgen Sie den folgenden Abschnitten.

Code von Grund auf

Projekt einrichten

Erstellen Sie in Android Studio ein neues Projekt für Phone und Tablet und wählen Sie die Blank-Vorlage aus. Nennen Sie die Anwendung super_resolution oder ähnlich.

Abhängigkeiten

Fügen Sie die folgenden Abhängigkeiten zur build.gradle der App hinzu.

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'

Projektressourcen

  1. Fügen Sie die Modelldatei als Rohressource hinzu.

    Erstellen Sie einen Ordner namens raw im Ordner src/main/res und verschieben oder kopieren Sie das ONNX-Modell in den Raw-Ordner.

  2. Fügen Sie das Testbild als Asset hinzu.

    Erstellen Sie einen Ordner namens assets im Hauptprojektordner und kopieren Sie das Bild, auf das Sie Super Resolution anwenden möchten, in diesen Ordner mit dem Dateinamen test_superresolution.png.

Hauptanwendungsklassen-Code

Erstellen Sie eine Datei namens MainActivity.kt und fügen Sie ihr die folgenden Codefragmente hinzu.

  1. Importanweisungen hinzufügen

    import ai.onnxruntime.*
    import ai.onnxruntime.extensions.OrtxPackage
    import android.annotation.SuppressLint
    import android.os.Bundle
    import android.widget.Button
    import android.widget.ImageView
    import android.widget.Toast
    import androidx.activity.*
    import androidx.appcompat.app.AppCompatActivity
    import kotlinx.android.synthetic.main.activity_main.*
    import kotlinx.coroutines.*
    import java.io.InputStream
    import java.util.*
    import java.util.concurrent.ExecutorService
    import java.util.concurrent.Executors
    
  2. Erstellen Sie die Hauptaktivitätsklasse und fügen Sie die Klassenvariablen hinzu.

    class MainActivity : AppCompatActivity() {
        private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
        private lateinit var ortSession: OrtSession
        private var inputImage: ImageView? = null
        private var outputImage: ImageView? = null
        private var superResolutionButton: Button? = null
    
        ...
    }
    
  3. Fügen Sie die onCreate()-Methode hinzu.

    Hier initialisieren wir die ONNX Runtime Session. Eine Session speichert einen Verweis auf das Modell, das zur Durchführung von Inferenz in der Anwendung verwendet wird. Sie enthält auch einen Parameter für Sitzungsoptionen, in dem Sie verschiedene Ausführungsanbieter (Hardwarebeschleuniger wie NNAPI) angeben können. In diesem Fall verwenden wir standardmäßig die CPU. Wir registrieren jedoch die benutzerdefinierte OP-Bibliothek, in der die Bildkodierungs- und Dekodierungsoperatoren am Ein- und Ausgang des Modells gefunden werden.

     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
         setContentView(R.layout.activity_main)
    
         inputImage = findViewById(R.id.imageView1)
         outputImage = findViewById(R.id.imageView2);
         superResolutionButton = findViewById(R.id.super_resolution_button)
         inputImage?.setImageBitmap(
             BitmapFactory.decodeStream(readInputImage())
         );
    
         // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.
         // Note: These are used to decode the input image into the format the original model requires,
         // and to encode the model output into png format
         val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()
         sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
         ortSession = ortEnv.createSession(readModel(), sessionOptions)
    
         superResolutionButton?.setOnClickListener {
             try {
                 performSuperResolution(ortSession)
                 Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT)
                     .show()
             } catch (e: Exception) {
                 Log.e(TAG, "Exception caught when perform super resolution", e)
                 Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT)
                     .show()
             }
         }
     }
    
  4. onDestroy-Methode hinzufügen

     override fun onDestroy() {
         super.onDestroy()
         ortEnv.close()
         ortSession.close()
     }
    
    
  5. updateUI-Methode hinzufügen

    private fun updateUI(result: Result) {
        outputImage?.setImageBitmap(result.outputBitmap)
    }
    
  6. readModel-Methode hinzufügen

    Diese Methode liest das ONNX-Modell aus dem Ressourcenordner.

    private fun readModel(): ByteArray {
        val modelID = R.pytorch_superresolution_with_pre_post_processing_op18
        return resources.openRawResource(modelID).readBytes()
    }   
    
  7. Methode zum Lesen des Eingabebilds hinzufügen

    Diese Methode liest ein Testbild aus dem Assets-Ordner. Aktuell liest sie ein festes Bild, das in die Anwendung integriert ist. Das Beispiel wird bald erweitert, um das Bild direkt von der Kamera oder der Fotobibliothek zu lesen.

    private fun readInputImage(): InputStream {
        return assets.open("test_superresolution.png")
    }   
    
  8. Methode zur Durchführung der Inferenz hinzufügen

    Diese Methode ruft die Methode auf, die das Herzstück der Anwendung ist: SuperResPerformer.upscale(), die Methode, die die Inferenz für das Modell ausführt. Der Code dafür ist im nächsten Abschnitt aufgeführt.

     private fun performSuperResolution(ortSession: OrtSession) {
         var superResPerformer = SuperResPerformer()
         var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession)
         updateUI(result);
     }   
    
  9. TAG-Objekt hinzufügen

    companion object {
        const val TAG = "ORTSuperResolution"
    }
    

Modellinferenzklassen-Code

Erstellen Sie eine Datei namens SuperResPerformer.kt und fügen Sie ihr die folgenden Codeausschnitte hinzu.

  1. Importe hinzufügen

    import ai.onnxruntime.OnnxJavaType
    import ai.onnxruntime.OrtSession
    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import android.graphics.Bitmap
    import android.graphics.BitmapFactory
    import java.io.InputStream
    import java.nio.ByteBuffer
    import java.util.*
    
  2. Ergebnisklasse erstellen

    internal data class Result(
        var outputBitmap: Bitmap? = null
    ) {}
    
  3. Super-Resolution-Performer-Klasse erstellen

    Diese Klasse und ihre Hauptfunktion upscale sind die Orte, an denen die meisten Aufrufe an ONNX Runtime erfolgen.

    • Die OrtEnvironment Singleton verwaltet Eigenschaften der Umgebung und konfigurierte Protokollierungsstufen.
    • OnnxTensor.createTensor() wird verwendet, um einen Tensor aus den Eingabebild-Bytes zu erstellen, der als Eingabe für das Modell geeignet ist.
    • OnnxJavaType.UINT8 ist der Datentyp des ByteBuffers des Eingabetensors.
    • OrtSession.run() führt die Inferenz (Vorhersage) auf dem Modell aus, um das Ausgabe-Upscaled-Bild zu erhalten.
    internal class SuperResPerformer(
    ) {
    
        fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {
            var result = Result()
    
            // Step 1: convert image into byte array (raw image bytes)
            val rawImageBytes = inputStream.readBytes()
    
            // Step 2: get the shape of the byte array and make ort tensor
            val shape = longArrayOf(rawImageBytes.size.toLong())
    
            val inputTensor = OnnxTensor.createTensor(
                ortEnv,
                ByteBuffer.wrap(rawImageBytes),
                shape,
                OnnxJavaType.UINT8
            )
            inputTensor.use {
                // Step 3: call ort inferenceSession run
                val output = ortSession.run(Collections.singletonMap("image", inputTensor))
    
                // Step 4: output analysis
                output.use {
                    val rawOutput = (output?.get(0)?.value) as ByteArray
                    val outputImageBitmap =
                        byteArrayToBitmap(rawOutput)
    
                    // Step 5: set output result
                    result.outputBitmap = outputImageBitmap
                }
            }
            return result
        }
    

App erstellen und ausführen

Innerhalb von Android Studio

  • Wählen Sie Build -> Projekt erstellen.
  • Ausführen -> App.

Die App wird im Geräteemulator ausgeführt. Verbinden Sie Ihr Android-Gerät, um die App auf dem Gerät auszuführen.

iOS-App

Voraussetzungen

  • Installieren Sie Xcode 13.0 und höher (vorzugsweise die neueste Version).
  • Ein iOS-Gerät oder ein iOS-Simulator
  • Xcode-Befehlszeilentools xcode-select --install
  • CocoaPods sudo gem install cocoapods
  • Eine gültige Apple Developer ID (falls Sie beabsichtigen, auf einem Gerät auszuführen).

Beispielcode

Den vollständigen Quellcode für die iOS-Super-Resolution-App finden Sie auf GitHub.

Um die App aus dem Quellcode auszuführen:

  1. Klonen Sie das onnxruntime-inference-examples Repository.

    git clone https://github.com/microsoft/onnxruntime-inference-examples
    cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
    
  2. Installieren Sie die erforderlichen Pod-Dateien.

    pod install
    
  3. Öffnen Sie die generierte Datei ORTSuperResolution.xcworkspace in Xcode.

    (Optional: Nur erforderlich, wenn Sie auf einem Gerät ausführen) Wählen Sie Ihr Entwicklungsteam aus.

  4. Führen Sie die Anwendung aus.

    Verbinden Sie Ihr iOS-Gerät oder Ihren Simulator, erstellen und führen Sie die App aus.

    Klicken Sie auf die Schaltfläche Perform Super Resolution, um die App in Aktion zu sehen.

Um die App Schritt für Schritt zu entwickeln, folgen Sie den folgenden Abschnitten.

Code von Grund auf

Projekt erstellen

Erstellen Sie in Xcode ein neues Projekt mit der APP-Vorlage.

Abhängigkeiten

Installieren Sie die folgenden Pods.

  # Pods for OrtSuperResolution
  pod 'onnxruntime-c'
  
  # Pre-release version pods
  pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'

Projektressourcen

  1. Fügen Sie die Modelldatei zum Projekt hinzu.

    Kopieren Sie die zu Beginn dieses Tutorials generierte Modelldatei in das Stammverzeichnis des Projektordners.

  2. Fügen Sie das Testbild als Asset hinzu.

    Kopieren Sie das Bild, auf das Sie Super Resolution anwenden möchten, in das Stammverzeichnis des Projektordners.

Haupt-App

Öffnen Sie die Datei namens ORTSuperResolutionApp.swift und fügen Sie den folgenden Code hinzu.

import SwiftUI

@main
struct ORTSuperResolutionApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

Ansichts-Inhalt

Öffnen Sie die Datei namens ContentView.swift und fügen Sie den folgenden Code hinzu.

import SwiftUI

struct ContentView: View {
    @State private var performSuperRes = false
    
    func runOrtSuperResolution() -> UIImage? {
        do {
            let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
            return outputImage
        } catch let error as NSError {
            print("Error: \(error.localizedDescription)")
            return nil
        }
    }
    
    var body: some View {
        ScrollView {
            VStack {
                VStack {
                    Text("ORTSuperResolution").font(.title).bold()
                        .frame(width: 400, height: 80)
                        .border(Color.purple, width: 4)
                        .background(Color.purple)
                    
                    Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                    
                    Image("cat_224x224").frame(width: 250, height: 250)
                    
                    Button("Perform Super Resolution") {
                        performSuperRes.toggle()
                    }
                    
                    if performSuperRes {
                        Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                        
                        if let outputImage = runOrtSuperResolution() {
                            Image(uiImage: outputImage)
                        } else {
                            Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
                        }
                    }
                    Spacer()
                }
            }
            .padding()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}

Swift / Objective-C-Brückenkopf

Erstellen Sie eine Datei namens ORTSuperResolution-Bridging-Header.h und fügen Sie die folgende Importanweisung hinzu.

#import "ORTSuperResolutionPerformer.h"

Super-Resolution-Code

  1. Erstellen Sie eine Datei namens ORTSuperResolutionPerformer.h und fügen Sie den folgenden Code hinzu.

    #ifndef ORTSuperResolutionPerformer_h
    #define ORTSuperResolutionPerformer_h
    
    #import <Foundation/Foundation.h>
    #import <UIKit/UIKit.h>
    
    NS_ASSUME_NONNULL_BEGIN
    
    @interface ORTSuperResolutionPerformer : NSObject
    
    + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error;
    
    @end
    
    NS_ASSUME_NONNULL_END
    
    #endif
    
  2. Erstellen Sie eine Datei namens ORTSuperResolutionPerformer.mm und fügen Sie den folgenden Code hinzu.

     #import "ORTSuperResolutionPerformer.h"
     #import <Foundation/Foundation.h>
     #import <UIKit/UIKit.h>
    
     #include <array>
     #include <cstdint>
     #include <stdexcept>
     #include <string>
     #include <vector>
    
     #include <onnxruntime_cxx_api.h>
     #include <onnxruntime_extensions.h>
    
    
     @implementation ORTSuperResolutionPerformer
    
     + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error {
            
         UIImage* output_image = nil;
            
         try {
                
             // Register custom ops
                
             const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
             auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution");
             auto session_options = Ort::SessionOptions();
                
             if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) {
                 throw std::runtime_error("RegisterCustomOps failed");
             }
                
             // Step 1: Load model
                
             NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16"
                                                                 ofType:@"onnx"];
             if (model_path == nullptr) {
                 throw std::runtime_error("Failed to get model path");
             }
                
             // Step 2: Create Ort Inference Session
                
             auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options);
                
             // Read input image
                
             // note: need to set Xcode settings to prevent it from messing with PNG files:
             // in "Build Settings":
             // - set "Compress PNG Files" to "No"
             // - set "Remove Text Metadata From PNG Files" to "No"
             NSString *input_image_path =
             [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"];
             if (input_image_path == nullptr) {
                 throw std::runtime_error("Failed to get image path");
             }
                
             // Step 3: Prepare input tensors and input/output names
                
             NSMutableData *input_data =
             [NSMutableData dataWithContentsOfFile:input_image_path];
             const int64_t input_data_length = input_data.length;
             const auto memoryInfo =
             Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
                
             const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length,
                                                             &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
                
             constexpr auto input_names = std::array{"image"};
             constexpr auto output_names = std::array{"image_out"};
                
             // Step 4: Call inference session run
                
             const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(),
                                         &input_tensor, 1, output_names.data(), 1);
             if (outputs.size() != 1) {
                 throw std::runtime_error("Unexpected number of outputs");
             }
                
             // Step 5: Analyze model outputs
                
             const auto &output_tensor = outputs.front();
             const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo();
             const auto output_shape = output_type_and_shape_info.GetShape();
                
             if (const auto output_element_type =
                 output_type_and_shape_info.GetElementType();
                 output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
                 throw std::runtime_error("Unexpected output element type");
             }
                
             const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
                
             // Step 6: Convert raw bytes into NSData and return as displayable UIImage
                
             NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])];
             output_image = [UIImage imageWithData:output_data];
                
         } catch (std::exception &e) {
             NSLog(@"%s error: %s", __FUNCTION__, e.what());
                
             static NSString *const kErrorDomain = @"ORTSuperResolution";
             constexpr NSInteger kErrorCode = 0;
             if (error) {
                 NSString *description =
                 [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding];
                 *error =
                 [NSError errorWithDomain:kErrorDomain
                                     code:kErrorCode
                                 userInfo:@{NSLocalizedDescriptionKey : description}];
             }
             return nullptr;
         }
            
         if (error) {
             *error = nullptr;
         }
         return output_image;
     }
    
     @end
    

App erstellen und ausführen

Wählen Sie in Xcode das Dreieckssymbol zum Erstellen aus, um die App zu erstellen und auszuführen!

Ressourcen

Original PyTorch Tutorial