Das Modell auf dem Gerät trainieren#
Sobald die Trainingsartefakte generiert sind, kann das Modell mit der Python-API von onnxruntime-Training auf dem Gerät trainiert werden.
Die erwarteten Trainingsartefakte sind
Das ONNX-Trainingsmodell
Der Checkpoint-Status
Das ONNX-Optimierungsmodell
Das Auswertungs-ONNX-Modell (optional)
Beispielverwendung
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
# Create the module
module = Module(path_to_the_training_model,
state,
path_to_the_eval_model,
device="cpu")
optimizer = Optimizer(path_to_the_optimizer_model, module)
# Training loop
for ...:
module.train()
training_loss = module(...)
optimizer.step()
module.lazy_reset_grad()
# Eval
module.eval()
eval_loss = module(...)
# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
- class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[Quelle]#
Basiert auf:
objectKlasse, die einen Modellparameter repräsentiert
Diese Klasse repräsentiert einen Modellparameter und bietet Zugriff auf seine Daten, Gradienten und andere Eigenschaften. Diese Klasse wird nicht erwartet, direkt instanziiert zu werden. Stattdessen wird sie vom CheckpointState-Objekt zurückgegeben.
- Parameter:
parameter – Das C.Parameter-Objekt, das die zugrunde liegenden Parameterdaten enthält.
state – Das C.CheckpointState-Objekt, das den zugrunde liegenden Sitzungsstatus enthält.
- class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[Quelle]#
Basiert auf:
objectKlasse, die alle Modellparameter enthält
Diese Klasse enthält alle Modellparameter und bietet Zugriff darauf. Diese Klasse wird nicht erwartet, direkt instanziiert zu werden. Stattdessen wird sie über das Parameter-Attribut von CheckpointState zurückgegeben. Diese Klasse verhält sich wie ein Wörterbuch und bietet Zugriff auf die Parameter nach Namen.
- Parameter:
state – Das C.CheckpointState-Objekt, das den zugrunde liegenden Sitzungsstatus enthält.
- __getitem__(name: str) Parameter[Quelle]#
Ruft den Parameter ab, der mit dem gegebenen Namen verknüpft ist
Sucht den Namen in den Parametern des Checkpoint-Status.
- Parameter:
name – Der Name des Parameters
- Rückgabe:
Der Wert des Parameters
- Ausnahmen:
KeyError – Wenn der Parameter nicht gefunden wird
- __setitem__(name: str, value: ndarray) None[Quelle]#
Setzt den Parameterwert für den gegebenen Namen
Sucht den Namen in den Parametern des Checkpoint-Status. Wenn der Name in den Parametern gefunden wird, wird der Wert aktualisiert.
- Parameter:
name – Der Name des Parameters
value – Der Wert des Parameters als NumPy-Array
- Ausnahmen:
KeyError – Wenn der Parameter nicht gefunden wird
- class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[Quelle]#
Basiert auf:
object- __getitem__(name: str) int | float | str[Quelle]#
Ruft die Eigenschaft ab, die mit dem gegebenen Namen verknüpft ist
Sucht den Namen in den Eigenschaften des Checkpoint-Status.
- Parameter:
name – Der Name der Eigenschaft
- Rückgabe:
Der Wert der Eigenschaft
- Ausnahmen:
KeyError – Wenn die Eigenschaft nicht gefunden wird
- __setitem__(name: str, value: int | float | str) None[Quelle]#
Setzt den Eigenschaftswert für den gegebenen Namen
Sucht den Namen in den Eigenschaften des Checkpoint-Status. Der Wert wird hinzugefügt oder in den Eigenschaften aktualisiert.
- Parameter:
name – Der Name der Eigenschaft
value – Der Wert der Eigenschaft. Properties unterstützen nur int-, float- und str-Werte.
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)[Quelle]#
Basiert auf:
objectKlasse, die den Status der Trainingssitzung enthält
Diese Klasse enthält alle Statusinformationen der Trainingssitzung, wie z. B. die Modellparameter, ihre Gradienten, den Optimierungsstatus und benutzerdefinierte Eigenschaften.
Um CheckpointState zu erstellen, verwenden Sie die Methode CheckpointState.load_checkpoint.
- Parameter:
state – Das C.Checkpoint-Statusobjekt, das den zugrunde liegenden Sitzungsstatus enthält.
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[Quelle]#
Lädt den Checkpoint-Status aus der Checkpoint-Datei
Die Checkpoint-Datei kann entweder der vollständige Checkpoint oder der nominelle Checkpoint sein.
- Parameter:
checkpoint_uri – Der Pfad zur Checkpoint-Datei.
- Rückgabe:
Das Checkpoint-Statusobjekt.
- Rückgabetyp:
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[Quelle]#
Speichert den Checkpoint-Status in der Checkpoint-Datei
- Parameter:
state – Das Checkpoint-Statusobjekt.
checkpoint_uri – Der Pfad zur Checkpoint-Datei.
include_optimizer_state – Wenn True, wird auch der Optimierungsstatus in die Checkpoint-Datei gespeichert.
- property parameters: Parameters#
Gibt die Modellparameter aus dem Checkpoint-Status zurück
- property properties: Properties#
Gibt die Eigenschaften aus dem Checkpoint-Status zurück
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[Quelle]#
Basiert auf:
objectTrainer-Klasse, die Trainings- und Auswertungsfunktionen für ONNX-Modelle bereitstellt.
Vor der Instanziierung der Module-Klasse wird erwartet, dass die Trainingsartefakte mit dem Dienstprogramm onnxruntime.training.artifacts.generate_artifacts generiert wurden.
- Die Trainingsartefakte umfassen
Das Trainingsmodell
Das Auswertungsmodell (optional)
Das Optimierungsmodell (optional)
Die Checkpoint-Datei
- training#
True, wenn sich das Modell im Trainingsmodus befindet, False, wenn es sich im Auswertungsmodus befindet.
- Typ:
- Parameter:
train_model_uri – Der Pfad zum Trainingsmodell.
state – Das Checkpoint-Statusobjekt.
eval_model_uri – Der Pfad zum Auswertungsmodell.
device – Das Gerät, auf dem das Modell ausgeführt werden soll. Standard: "cpu".
session_options – Die Sitzungsoptionen, die für das Modell verwendet werden sollen.
- __call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[Quelle]#
Ruft entweder den Trainings- oder den Auswertungsschritt des Modells auf.
- Parameter:
*user_inputs – Die Eingaben für das Modell. Die Benutzereingaben können entweder NumPy-Arrays oder OrtValues sein.
- Rückgabe:
Die Ausgaben des Modells.
- train(mode: bool = True) Module[Quelle]#
Setzt das Modul in den Trainingsmodus.
- Parameter:
mode – Ob das Modell in den Trainingsmodus (True) oder den Auswertungsmodus (False) gesetzt werden soll. Standard: True.
- Rückgabe:
self
- lazy_reset_grad()[Quelle]#
Setzt die Trainingsgradienten verzögert zurück.
Diese Funktion setzt den internen Status des Moduls so, dass die Modulgradienten unmittelbar vor der Berechnung neuer Gradienten bei der nächsten Ausführung von train() zurückgesetzt werden.
- get_contiguous_parameters(trainable_only: bool = False) OrtValue[Quelle]#
Erstellt einen zusammenhängenden Puffer der Trainingssitzungsparameter
- Parameter:
trainable_only – Wenn True, werden nur trainierbare Parameter berücksichtigt. Andernfalls werden alle Parameter berücksichtigt.
- Rückgabe:
Der zusammenhängende Puffer der Trainingssitzungsparameter.
- get_parameters_size(trainable_only: bool = True) int[Quelle]#
Gibt die Größe der Parameter zurück.
- Parameter:
trainable_only – Wenn True, werden nur trainierbare Parameter berücksichtigt. Andernfalls werden alle Parameter berücksichtigt.
- Rückgabe:
Die Anzahl der primitiven (z. B. Fließkomma-) Elemente in den Parametern.
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#
Kopiert den OrtValue-Puffer in die Trainingssitzungsparameter.
Falls das Modul aus einem nominalen Checkpoint geladen wurde, muss diese Funktion aufgerufen werden, um die aktualisierten Parameter auf den Checkpoint zu laden, um ihn zu vervollständigen.
- Parameter:
buffer – Der OrtValue-Puffer, der in die Trainingssitzungsparameter kopiert werden soll.
- export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) None[source]#
Exportiert das Modell für die Inferenz.
Sobald das Training abgeschlossen ist, kann diese Funktion verwendet werden, um die trainingsspezifischen Knoten im ONNX-Modell zu entfernen. Insbesondere führt diese Funktion Folgendes aus:
Durchläuft den Trainingsgraphen und identifiziert Knoten, die die gegebenen Ausgabenamen erzeugen.
Entfernt alle nachfolgenden Knoten im Graphen, da sie für den Inferenzgraphen nicht relevant sind.
- Parameter:
inference_model_uri – Der Pfad zum Inferenzmodell.
graph_output_names – Die Liste der Ausgabenamen, die für die Inferenz erforderlich sind.
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#
Basiert auf:
objectKlasse, die Methoden zur Aktualisierung der Modellparameter basierend auf den berechneten Gradienten bereitstellt.
- Parameter:
optimizer_uri – Der Pfad zum Optimierermodell.
module – Das zu trainierende Modul.
- step() None[source]#
Aktualisiert die Modellparameter basierend auf den berechneten Gradienten.
Diese Methode aktualisiert die Modellparameter, indem sie einen Schritt in Richtung der berechneten Gradienten macht. Der verwendete Optimierer hängt vom bereitgestellten Optimierermodell ab.
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#
Basiert auf:
objectAktualisiert linear die Lernrate im Optimierer
Der lineare Lernraten-Scheduler verringert die Lernrate durch einen linear aktualisierten multiplikativen Faktor von der anfänglichen Lernrate, die in der Trainingssitzung festgelegt wurde, auf 0. Die Verringerung erfolgt nach der anfänglichen Aufwärmphase, in der die Lernrate linear von 0 auf die angegebene anfängliche Lernrate erhöht wird.
- Parameter:
optimizer – Der onnxruntime-Trainingsoptimierer des Benutzers
warmup_step_count – Die Anzahl der Schritte in der Aufwärmphase.
total_step_count – Die Gesamtzahl der Trainingsschritte.
initial_lr – Die anfängliche Lernrate.