Klasse OrtTrainingSession

  • Alle implementierten Schnittstellen
    java.lang.AutoCloseable

    public final class OrtTrainingSession
    extends java.lang.Object
    implements java.lang.AutoCloseable
    Umschließt ein ONNX-Trainingsmodell und ermöglicht Trainings- und Inferenzaufrufe.

    Ermöglicht die Inspektion der Eingabe- und Ausgabeknoten des Modells. Erzeugt von einer OrtEnvironment.

    Die meisten Instanzmethoden werfen IllegalStateException, wenn die Sitzung geschlossen ist und die Methoden aufgerufen werden.

    • Zusammenfassung der Methoden

      Alle Methoden Statische Methoden Instanzmethoden Konkrete Methoden 
      Modifikator und Typ Methode Beschreibung
      void addProperty​(java.lang.String name, float value)
      Fügt eine Float-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
      void addProperty​(java.lang.String name, int value)
      Fügt eine Integer-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
      void addProperty​(java.lang.String name, java.lang.String value)
      Fügt eine String-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
      void close()  
      OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
      Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
      OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
      Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
      OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
      Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
      OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
      Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
      OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
      Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
      void exportModelForInference​(java.nio.file.Path outputPath, java.lang.String[] outputNames)
      Exportiert das Evaluationsmodell als ein für die Inferenz geeignetes Modell, wobei die gewünschten Knoten als Ausgabeknoten festgelegt werden.
      java.util.Set<java.lang.String> getEvalInputNames()
      Gibt ein geordnetes Set der Eingabenamen des Evaluationsmodells zurück.
      java.util.Set<java.lang.String> getEvalOutputNames()
      Gibt ein geordnetes Set der Ausgabenamen des Evaluationsmodells zurück.
      float getFloatProperty​(java.lang.String name)
      Ruft eine Float-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
      int getIntProperty​(java.lang.String name)
      Ruft eine Integer-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
      float getLearningRate()
      Ruft die aktuelle Lernrate für diese Trainingssitzung ab.
      java.lang.String getStringProperty​(java.lang.String name)
      Ruft eine String-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
      java.util.Set<java.lang.String> getTrainInputNames()
      Gibt ein geordnetes Set der Eingabenamen des Trainingsmodells zurück.
      java.util.Set<java.lang.String> getTrainOutputNames()
      Gibt ein geordnetes Set der Ausgabenamen des Trainingsmodells zurück.
      void lazyResetGrad()
      Stellt sicher, dass die Gradienten vor dem nächsten Aufruf von trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>) auf Null zurückgesetzt werden.
      void optimizerStep()
      Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.
      void optimizerStep​(OrtSession.RunOptions runOptions)
      Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.
      void registerLinearLRScheduler​(long warmupSteps, long totalSteps, float initialLearningRate)
      Registriert einen linearen Lernraten-Scheduler mit linearer Aufwärmphase.
      void saveCheckpoint​(java.nio.file.Path outputPath, boolean saveOptimizer)
      Speichert den Zustand der Trainingssitzung im bereitgestellten Checkpoint-Verzeichnis.
      void schedulerStep()
      Aktualisiert die Lernrate basierend auf dem registrierten Lernraten-Scheduler.
      void setLearningRate​(float learningRate)
      Legt die Lernrate für die Trainingssitzung fest.
      static void setSeed​(long seed)
      Legt den von ONNX Runtime verwendeten RNG-Seed fest.
      OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
      Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
      OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
      Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
      OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
      Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
      OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
      Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
      OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
      Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
      • Methoden geerbt von Klasse java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Detail der Methoden

      • getTrainInputNames

        public java.util.Set<java.lang.String> getTrainInputNames()
        Gibt ein geordnetes Set der Eingabenamen des Trainingsmodells zurück.
        Rückgabe
        Die Trainingseingaben.
      • getTrainOutputNames

        public java.util.Set<java.lang.String> getTrainOutputNames()
        Gibt ein geordnetes Set der Ausgabenamen des Trainingsmodells zurück.
        Rückgabe
        Die Trainingsausgaben.
      • getEvalInputNames

        public java.util.Set<java.lang.String> getEvalInputNames()
        Gibt ein geordnetes Set der Eingabenamen des Evaluationsmodells zurück.
        Rückgabe
        Die Auswertungseingaben.
      • getEvalOutputNames

        public java.util.Set<java.lang.String> getEvalOutputNames()
        Gibt ein geordnetes Set der Ausgabenamen des Evaluationsmodells zurück.
        Rückgabe
        Die Auswertungsausgaben.
      • addProperty

        public void addProperty​(java.lang.String name,
                                float value)
                         throws OrtException
        Fügt eine Float-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
        Parameter
        name - Der Name der Eigenschaft.
        value - Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn der Aufruf fehlschlägt.
      • addProperty

        public void addProperty​(java.lang.String name,
                                int value)
                         throws OrtException
        Fügt eine Integer-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
        Parameter
        name - Der Name der Eigenschaft.
        value - Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn der Aufruf fehlschlägt.
      • addProperty

        public void addProperty​(java.lang.String name,
                                java.lang.String value)
                         throws OrtException
        Fügt eine String-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.
        Parameter
        name - Der Name der Eigenschaft.
        value - Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn der Aufruf fehlschlägt.
      • getFloatProperty

        public float getFloatProperty​(java.lang.String name)
                               throws OrtException
        Ruft eine Float-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
        Parameter
        name - Der Name der Eigenschaft.
        Rückgabe
        Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
      • getIntProperty

        public int getIntProperty​(java.lang.String name)
                           throws OrtException
        Ruft eine Integer-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
        Parameter
        name - Der Name der Eigenschaft.
        Rückgabe
        Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
      • getStringProperty

        public java.lang.String getStringProperty​(java.lang.String name)
                                           throws OrtException
        Ruft eine String-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.
        Parameter
        name - Der Name der Eigenschaft.
        Rückgabe
        Der Wert der Eigenschaft.
        Wirft
        OrtException - Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
      • close

        public void close()
        Spezifiziert von
        close in Schnittstelle java.lang.AutoCloseable
      • saveCheckpoint

        public void saveCheckpoint​(java.nio.file.Path outputPath,
                                   boolean saveOptimizer)
                            throws OrtException
        Speichert den Zustand der Trainingssitzung im bereitgestellten Checkpoint-Verzeichnis.
        Parameter
        outputPath - Pfad zu einem Checkpoint-Verzeichnis.
        saveOptimizer - Sollen die Optimierungszustände gespeichert werden.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • setSeed

        public static void setSeed​(long seed)
                            throws OrtException
        Legt den von ONNX Runtime verwendeten RNG-Seed fest.

        Hinweis: Diese Einstellung ist global für alle OrtTrainingSession-Instanzen.

        Parameter
        seed - Der RNG-Seed.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                    throws OrtException
        Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
        Parameter
        inputs - Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).
        Rückgabe
        Alle von Schritt trainStep erzeugten Ausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
        Parameter
        inputs - Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).
        runOptions - Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.
        Rückgabe
        Alle von Schritt trainStep erzeugten Ausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs)
                                    throws OrtException
        Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
        Parameter
        inputs - Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).
        requestedOutputs - Die angeforderten Ausgaben.
        Rückgabe
        Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                    throws OrtException
        Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.

        Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map.

        Hinweis: Angeheftete Ausgaben gehören nicht zum OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.

        Parameter
        inputs - Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).
        pinnedOutputs - Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.
        Rückgabe
        Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.

        Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map, mit angehefteten Ausgaben zuerst, dann angeforderten Ausgaben. Eine IllegalArgumentException wird ausgelöst, wenn derselbe Ausgabenname sowohl in den angeforderten als auch in den angehefteten Ausgaben vorkommt.

        Hinweis: Angeheftete Ausgaben gehören nicht zum OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.

        Parameter
        inputs - Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).
        requestedOutputs - Die angeforderten Ausgaben, die von ORT zugewiesen werden.
        pinnedOutputs - Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.
        runOptions - Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.
        Rückgabe
        Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                   throws OrtException
        Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
        Parameter
        inputs - Die Modelleingaben.
        Rückgabe
        Alle Modellausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
        Parameter
        inputs - Die Modelleingaben.
        runOptions - Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.
        Rückgabe
        Alle Modellausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs)
                                   throws OrtException
        Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.
        Parameter
        inputs - Die Modelleingaben.
        requestedOutputs - Die Namen der angeforderten Ausgaben.
        Rückgabe
        Die angeforderten Ausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                   throws OrtException
        Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.

        Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map.

        Hinweis: Angeheftete Ausgaben gehören nicht zum OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.

        Parameter
        inputs - Die zu bewertenden Eingaben.
        pinnedOutputs - Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.
        Rückgabe
        Die angeforderten Ausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.

        Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map, mit angehefteten Ausgaben zuerst, dann angeforderten Ausgaben. Eine IllegalArgumentException wird ausgelöst, wenn derselbe Ausgabenname sowohl in den angeforderten als auch in den angehefteten Ausgaben vorkommt.

        Hinweis: Angeheftete Ausgaben gehören nicht zum OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.

        Parameter
        inputs - Die zu bewertenden Eingaben.
        requestedOutputs - Die angeforderten Ausgaben, die von ORT zugewiesen werden.
        pinnedOutputs - Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.
        runOptions - Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.
        Rückgabe
        Die angeforderten Ausgaben.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • setLearningRate

        public void setLearningRate​(float learningRate)
                             throws OrtException
        Legt die Lernrate für die Trainingssitzung fest.

        Sollte nur verwendet werden, wenn kein Lernraten-Scheduler in der Sitzung vorhanden ist. Wird nicht verwendet, um die anfängliche Lernrate für LR-Scheduler festzulegen.

        Parameter
        learningRate - Die Lernrate.
        Wirft
        OrtException - Wenn der Aufruf fehlschlägt.
      • getLearningRate

        public float getLearningRate()
                              throws OrtException
        Ruft die aktuelle Lernrate für diese Trainingssitzung ab.
        Rückgabe
        Die aktuelle Lernrate.
        Wirft
        OrtException - Wenn der Aufruf fehlschlägt.
      • optimizerStep

        public void optimizerStep()
                           throws OrtException
        Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • optimizerStep

        public void optimizerStep​(OrtSession.RunOptions runOptions)
                           throws OrtException
        Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.

        Die Ausführungsoptionen können zur Steuerung der Protokollierung und zur vorzeitigen Beendigung des Aufrufs verwendet werden.

        Parameter
        runOptions - Optionen zur Steuerung der Modellausführung.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • registerLinearLRScheduler

        public void registerLinearLRScheduler​(long warmupSteps,
                                              long totalSteps,
                                              float initialLearningRate)
                                       throws OrtException
        Registriert einen linearen Lernraten-Scheduler mit linearer Aufwärmphase.
        Parameter
        warmupSteps - Die Anzahl der Schritte, um die Lernrate von Null auf initialLearningRate zu erhöhen.
        totalSteps - Die Gesamtzahl der Schritte, über die dieser Scheduler läuft.
        initialLearningRate - Die maximale Lernrate.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • schedulerStep

        public void schedulerStep()
                           throws OrtException
        Aktualisiert die Lernrate basierend auf dem registrierten Lernraten-Scheduler.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.
      • exportModelForInference

        public void exportModelForInference​(java.nio.file.Path outputPath,
                                            java.lang.String[] outputNames)
                                     throws OrtException
        Exportiert das Evaluationsmodell als ein für die Inferenz geeignetes Modell, wobei die gewünschten Knoten als Ausgabeknoten festgelegt werden.

        Hinweis: Diese Methode lädt das Evaluationsmodell erneut von dem Pfad, der für die Trainingssitzung bereitgestellt wurde, und dieser Pfad muss immer noch gültig sein.

        Parameter
        outputPath - Der Pfad, auf den das Inferenzmodell geschrieben werden soll.
        outputNames - Die Namen der Ausgabeknoten.
        Wirft
        OrtException - Wenn der native Aufruf fehlgeschlagen ist.