Klasse OrtTrainingSession
- java.lang.Object
-
- ai.onnxruntime.OrtTrainingSession
-
- Alle implementierten Schnittstellen
java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseableUmschließt ein ONNX-Trainingsmodell und ermöglicht Trainings- und Inferenzaufrufe.Ermöglicht die Inspektion der Eingabe- und Ausgabeknoten des Modells. Erzeugt von einer
OrtEnvironment.Die meisten Instanzmethoden werfen
IllegalStateException, wenn die Sitzung geschlossen ist und die Methoden aufgerufen werden.
-
-
Zusammenfassung der Methoden
Alle Methoden Statische Methoden Instanzmethoden Konkrete Methoden Modifikator und Typ Methode Beschreibung voidaddProperty(java.lang.String name, float value)Fügt eine Float-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.voidaddProperty(java.lang.String name, int value)Fügt eine Integer-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.voidaddProperty(java.lang.String name, java.lang.String value)Fügt eine String-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.voidclose()OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.voidexportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)Exportiert das Evaluationsmodell als ein für die Inferenz geeignetes Modell, wobei die gewünschten Knoten als Ausgabeknoten festgelegt werden.java.util.Set<java.lang.String>getEvalInputNames()Gibt ein geordnetes Set der Eingabenamen des Evaluationsmodells zurück.java.util.Set<java.lang.String>getEvalOutputNames()Gibt ein geordnetes Set der Ausgabenamen des Evaluationsmodells zurück.floatgetFloatProperty(java.lang.String name)Ruft eine Float-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.intgetIntProperty(java.lang.String name)Ruft eine Integer-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.floatgetLearningRate()Ruft die aktuelle Lernrate für diese Trainingssitzung ab.java.lang.StringgetStringProperty(java.lang.String name)Ruft eine String-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.java.util.Set<java.lang.String>getTrainInputNames()Gibt ein geordnetes Set der Eingabenamen des Trainingsmodells zurück.java.util.Set<java.lang.String>getTrainOutputNames()Gibt ein geordnetes Set der Ausgabenamen des Trainingsmodells zurück.voidlazyResetGrad()Stellt sicher, dass die Gradienten vor dem nächsten Aufruf vontrainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)auf Null zurückgesetzt werden.voidoptimizerStep()Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.voidoptimizerStep(OrtSession.RunOptions runOptions)Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.voidregisterLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)Registriert einen linearen Lernraten-Scheduler mit linearer Aufwärmphase.voidsaveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)Speichert den Zustand der Trainingssitzung im bereitgestellten Checkpoint-Verzeichnis.voidschedulerStep()Aktualisiert die Lernrate basierend auf dem registrierten Lernraten-Scheduler.voidsetLearningRate(float learningRate)Legt die Lernrate für die Trainingssitzung fest.static voidsetSeed(long seed)Legt den von ONNX Runtime verwendeten RNG-Seed fest.OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.
-
-
-
Detail der Methoden
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
Gibt ein geordnetes Set der Eingabenamen des Trainingsmodells zurück.- Rückgabe
- Die Trainingseingaben.
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
Gibt ein geordnetes Set der Ausgabenamen des Trainingsmodells zurück.- Rückgabe
- Die Trainingsausgaben.
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
Gibt ein geordnetes Set der Eingabenamen des Evaluationsmodells zurück.- Rückgabe
- Die Auswertungseingaben.
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
Gibt ein geordnetes Set der Ausgabenamen des Evaluationsmodells zurück.- Rückgabe
- Die Auswertungsausgaben.
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtExceptionFügt eine Float-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.- Parameter
name- Der Name der Eigenschaft.value- Der Wert der Eigenschaft.- Wirft
OrtException- Wenn der Aufruf fehlschlägt.
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtExceptionFügt eine Integer-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.- Parameter
name- Der Name der Eigenschaft.value- Der Wert der Eigenschaft.- Wirft
OrtException- Wenn der Aufruf fehlschlägt.
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtExceptionFügt eine String-Eigenschaft zum Checkpoint dieser Trainingssitzung hinzu.- Parameter
name- Der Name der Eigenschaft.value- Der Wert der Eigenschaft.- Wirft
OrtException- Wenn der Aufruf fehlschlägt.
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtExceptionRuft eine Float-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.- Parameter
name- Der Name der Eigenschaft.- Rückgabe
- Der Wert der Eigenschaft.
- Wirft
OrtException- Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtExceptionRuft eine Integer-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.- Parameter
name- Der Name der Eigenschaft.- Rückgabe
- Der Wert der Eigenschaft.
- Wirft
OrtException- Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtExceptionRuft eine String-Eigenschaft vom Checkpoint dieser Trainingssitzung ab.- Parameter
name- Der Name der Eigenschaft.- Rückgabe
- Der Wert der Eigenschaft.
- Wirft
OrtException- Wenn die Eigenschaft nicht existiert oder vom falschen Typ ist.
-
close
public void close()
- Spezifiziert von
closein Schnittstellejava.lang.AutoCloseable
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtExceptionSpeichert den Zustand der Trainingssitzung im bereitgestellten Checkpoint-Verzeichnis.- Parameter
outputPath- Pfad zu einem Checkpoint-Verzeichnis.saveOptimizer- Sollen die Optimierungszustände gespeichert werden.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
lazyResetGrad
public void lazyResetGrad() throws OrtExceptionStellt sicher, dass die Gradienten vor dem nächsten Aufruf vontrainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)auf Null zurückgesetzt werden.Hinweis: Dies ist ein Lazy-Aufruf, die Gradienten werden als Teil der Ausführung des nächsten
trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)gelöscht und nicht vorher.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
setSeed
public static void setSeed(long seed) throws OrtExceptionLegt den von ONNX Runtime verwendeten RNG-Seed fest.Hinweis: Diese Einstellung ist global für alle OrtTrainingSession-Instanzen.
- Parameter
seed- Der RNG-Seed.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.- Parameter
inputs- Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).- Rückgabe
- Alle von Schritt trainStep erzeugten Ausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.- Parameter
inputs- Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).runOptions- Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.- Rückgabe
- Alle von Schritt trainStep erzeugten Ausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.- Parameter
inputs- Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).requestedOutputs- Die angeforderten Ausgaben.- Rückgabe
- Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map.
Hinweis: Angeheftete Ausgaben gehören nicht zum
OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.- Parameter
inputs- Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).pinnedOutputs- Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.- Rückgabe
- Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Führt einen einzelnen Trainingsschritt durch und akkumuliert die Gradienten.Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map, mit angehefteten Ausgaben zuerst, dann angeforderten Ausgaben. Eine
IllegalArgumentExceptionwird ausgelöst, wenn derselbe Ausgabenname sowohl in den angeforderten als auch in den angehefteten Ausgaben vorkommt.Hinweis: Angeheftete Ausgaben gehören nicht zum
OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.- Parameter
inputs- Die Eingaben (müssen sowohl die Merkmale als auch das Ziel enthalten).requestedOutputs- Die angeforderten Ausgaben, die von ORT zugewiesen werden.pinnedOutputs- Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.runOptions- Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.- Rückgabe
- Angeforderte Ausgaben, die von Schritt trainStep erzeugt wurden.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.- Parameter
inputs- Die Modelleingaben.- Rückgabe
- Alle Modellausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.- Parameter
inputs- Die Modelleingaben.runOptions- Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.- Rückgabe
- Alle Modellausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.- Parameter
inputs- Die Modelleingaben.requestedOutputs- Die Namen der angeforderten Ausgaben.- Rückgabe
- Die angeforderten Ausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map.
Hinweis: Angeheftete Ausgaben gehören nicht zum
OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.- Parameter
inputs- Die zu bewertenden Eingaben.pinnedOutputs- Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.- Rückgabe
- Die angeforderten Ausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Führt einen einzelnen Evaluationsschritt mit den bereitgestellten Eingaben durch.Die Ausgaben sind sortiert nach der Reihenfolge der Traversierung der bereitgestellten Map, mit angehefteten Ausgaben zuerst, dann angeforderten Ausgaben. Eine
IllegalArgumentExceptionwird ausgelöst, wenn derselbe Ausgabenname sowohl in den angeforderten als auch in den angehefteten Ausgaben vorkommt.Hinweis: Angeheftete Ausgaben gehören nicht zum
OrtSession.Result-Objekt und werden **nicht** geschlossen, wenn das Ergebnisobjekt geschlossen wird.- Parameter
inputs- Die zu bewertenden Eingaben.requestedOutputs- Die angeforderten Ausgaben, die von ORT zugewiesen werden.pinnedOutputs- Die angeforderten Ausgaben, die der Benutzer zugewiesen hat.runOptions- Ausführungsoptionen zur Steuerung dieses spezifischen Aufrufs.- Rückgabe
- Die angeforderten Ausgaben.
- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtExceptionLegt die Lernrate für die Trainingssitzung fest.Sollte nur verwendet werden, wenn kein Lernraten-Scheduler in der Sitzung vorhanden ist. Wird nicht verwendet, um die anfängliche Lernrate für LR-Scheduler festzulegen.
- Parameter
learningRate- Die Lernrate.- Wirft
OrtException- Wenn der Aufruf fehlschlägt.
-
getLearningRate
public float getLearningRate() throws OrtExceptionRuft die aktuelle Lernrate für diese Trainingssitzung ab.- Rückgabe
- Die aktuelle Lernrate.
- Wirft
OrtException- Wenn der Aufruf fehlschlägt.
-
optimizerStep
public void optimizerStep() throws OrtExceptionWendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
Wendet die Gradientenaktualisierungen auf die trainierbaren Parameter mit dem Optimierungsmodell an.Die Ausführungsoptionen können zur Steuerung der Protokollierung und zur vorzeitigen Beendigung des Aufrufs verwendet werden.
- Parameter
runOptions- Optionen zur Steuerung der Modellausführung.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtExceptionRegistriert einen linearen Lernraten-Scheduler mit linearer Aufwärmphase.- Parameter
warmupSteps- Die Anzahl der Schritte, um die Lernrate von Null aufinitialLearningRatezu erhöhen.totalSteps- Die Gesamtzahl der Schritte, über die dieser Scheduler läuft.initialLearningRate- Die maximale Lernrate.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
schedulerStep
public void schedulerStep() throws OrtExceptionAktualisiert die Lernrate basierend auf dem registrierten Lernraten-Scheduler.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtExceptionExportiert das Evaluationsmodell als ein für die Inferenz geeignetes Modell, wobei die gewünschten Knoten als Ausgabeknoten festgelegt werden.Hinweis: Diese Methode lädt das Evaluationsmodell erneut von dem Pfad, der für die Trainingssitzung bereitgestellt wurde, und dieser Pfad muss immer noch gültig sein.
- Parameter
outputPath- Der Pfad, auf den das Inferenzmodell geschrieben werden soll.outputNames- Die Namen der Ausgabeknoten.- Wirft
OrtException- Wenn der native Aufruf fehlgeschlagen ist.
-
-