Das Modell auf dem Gerät trainieren#

Sobald die Trainingsartefakte generiert sind, kann das Modell mit der Python-API von onnxruntime-Training auf dem Gerät trainiert werden.

Die erwarteten Trainingsartefakte sind

  1. Das ONNX-Trainingsmodell

  2. Der Checkpoint-Status

  3. Das ONNX-Optimierungsmodell

  4. Das Auswertungs-ONNX-Modell (optional)

Beispielverwendung

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[Quelle]#

Basiert auf: object

Klasse, die einen Modellparameter repräsentiert

Diese Klasse repräsentiert einen Modellparameter und bietet Zugriff auf seine Daten, Gradienten und andere Eigenschaften. Diese Klasse wird nicht erwartet, direkt instanziiert zu werden. Stattdessen wird sie vom CheckpointState-Objekt zurückgegeben.

Parameter:
  • parameter – Das C.Parameter-Objekt, das die zugrunde liegenden Parameterdaten enthält.

  • state – Das C.CheckpointState-Objekt, das den zugrunde liegenden Sitzungsstatus enthält.

property name: str#

Der Name des Parameters

property data: ndarray#

Die Daten des Parameters

property grad: ndarray#

Der Gradient des Parameters

property requires_grad: bool#

Ob der Parameter die Berechnung seines Gradienten erfordert oder nicht

__repr__() str[Quelle]#

Gibt eine String-Darstellung des Parameters zurück

class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[Quelle]#

Basiert auf: object

Klasse, die alle Modellparameter enthält

Diese Klasse enthält alle Modellparameter und bietet Zugriff darauf. Diese Klasse wird nicht erwartet, direkt instanziiert zu werden. Stattdessen wird sie über das Parameter-Attribut von CheckpointState zurückgegeben. Diese Klasse verhält sich wie ein Wörterbuch und bietet Zugriff auf die Parameter nach Namen.

Parameter:

state – Das C.CheckpointState-Objekt, das den zugrunde liegenden Sitzungsstatus enthält.

__getitem__(name: str) Parameter[Quelle]#

Ruft den Parameter ab, der mit dem gegebenen Namen verknüpft ist

Sucht den Namen in den Parametern des Checkpoint-Status.

Parameter:

name – Der Name des Parameters

Rückgabe:

Der Wert des Parameters

Ausnahmen:

KeyError – Wenn der Parameter nicht gefunden wird

__setitem__(name: str, value: ndarray) None[Quelle]#

Setzt den Parameterwert für den gegebenen Namen

Sucht den Namen in den Parametern des Checkpoint-Status. Wenn der Name in den Parametern gefunden wird, wird der Wert aktualisiert.

Parameter:
  • name – Der Name des Parameters

  • value – Der Wert des Parameters als NumPy-Array

Ausnahmen:

KeyError – Wenn der Parameter nicht gefunden wird

__contains__(name: str) bool[Quelle]#

Prüft, ob der Parameter im Status vorhanden ist

Parameter:

name – Der Name des Parameters

Rückgabe:

True, wenn der Name ein Parameter ist, sonst False

__iter__()[Quelle]#

Gibt einen Iterator über die Eigenschaften zurück

__repr__() str[Quelle]#

Gibt eine String-Darstellung der Parameter zurück

__len__() int[Quelle]#

Gibt die Anzahl der Parameter zurück

class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[Quelle]#

Basiert auf: object

__getitem__(name: str) int | float | str[Quelle]#

Ruft die Eigenschaft ab, die mit dem gegebenen Namen verknüpft ist

Sucht den Namen in den Eigenschaften des Checkpoint-Status.

Parameter:

name – Der Name der Eigenschaft

Rückgabe:

Der Wert der Eigenschaft

Ausnahmen:

KeyError – Wenn die Eigenschaft nicht gefunden wird

__setitem__(name: str, value: int | float | str) None[Quelle]#

Setzt den Eigenschaftswert für den gegebenen Namen

Sucht den Namen in den Eigenschaften des Checkpoint-Status. Der Wert wird hinzugefügt oder in den Eigenschaften aktualisiert.

Parameter:
  • name – Der Name der Eigenschaft

  • value – Der Wert der Eigenschaft. Properties unterstützen nur int-, float- und str-Werte.

__contains__(name: str) bool[Quelle]#

Prüft, ob die Eigenschaft im Status vorhanden ist

Parameter:

name – Der Name der Eigenschaft

Rückgabe:

True, wenn der Name eine Eigenschaft ist, sonst False

__iter__()[Quelle]#

Gibt einen Iterator über die Eigenschaften zurück

__repr__() str[Quelle]#

Gibt eine String-Darstellung der Eigenschaften zurück

__len__() int[Quelle]#

Gibt die Anzahl der Eigenschaften zurück

class onnxruntime.training.api.CheckpointState(state: CheckpointState)[Quelle]#

Basiert auf: object

Klasse, die den Status der Trainingssitzung enthält

Diese Klasse enthält alle Statusinformationen der Trainingssitzung, wie z. B. die Modellparameter, ihre Gradienten, den Optimierungsstatus und benutzerdefinierte Eigenschaften.

Um CheckpointState zu erstellen, verwenden Sie die Methode CheckpointState.load_checkpoint.

Parameter:

state – Das C.Checkpoint-Statusobjekt, das den zugrunde liegenden Sitzungsstatus enthält.

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[Quelle]#

Lädt den Checkpoint-Status aus der Checkpoint-Datei

Die Checkpoint-Datei kann entweder der vollständige Checkpoint oder der nominelle Checkpoint sein.

Parameter:

checkpoint_uri – Der Pfad zur Checkpoint-Datei.

Rückgabe:

Das Checkpoint-Statusobjekt.

Rückgabetyp:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[Quelle]#

Speichert den Checkpoint-Status in der Checkpoint-Datei

Parameter:
  • state – Das Checkpoint-Statusobjekt.

  • checkpoint_uri – Der Pfad zur Checkpoint-Datei.

  • include_optimizer_state – Wenn True, wird auch der Optimierungsstatus in die Checkpoint-Datei gespeichert.

property parameters: Parameters#

Gibt die Modellparameter aus dem Checkpoint-Status zurück

property properties: Properties#

Gibt die Eigenschaften aus dem Checkpoint-Status zurück

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[Quelle]#

Basiert auf: object

Trainer-Klasse, die Trainings- und Auswertungsfunktionen für ONNX-Modelle bereitstellt.

Vor der Instanziierung der Module-Klasse wird erwartet, dass die Trainingsartefakte mit dem Dienstprogramm onnxruntime.training.artifacts.generate_artifacts generiert wurden.

Die Trainingsartefakte umfassen
  • Das Trainingsmodell

  • Das Auswertungsmodell (optional)

  • Das Optimierungsmodell (optional)

  • Die Checkpoint-Datei

training#

True, wenn sich das Modell im Trainingsmodus befindet, False, wenn es sich im Auswertungsmodus befindet.

Typ:

bool

Parameter:
  • train_model_uri – Der Pfad zum Trainingsmodell.

  • state – Das Checkpoint-Statusobjekt.

  • eval_model_uri – Der Pfad zum Auswertungsmodell.

  • device – Das Gerät, auf dem das Modell ausgeführt werden soll. Standard: "cpu".

  • session_options – Die Sitzungsoptionen, die für das Modell verwendet werden sollen.

__call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[Quelle]#

Ruft entweder den Trainings- oder den Auswertungsschritt des Modells auf.

Parameter:

*user_inputs – Die Eingaben für das Modell. Die Benutzereingaben können entweder NumPy-Arrays oder OrtValues sein.

Rückgabe:

Die Ausgaben des Modells.

train(mode: bool = True) Module[Quelle]#

Setzt das Modul in den Trainingsmodus.

Parameter:

mode – Ob das Modell in den Trainingsmodus (True) oder den Auswertungsmodus (False) gesetzt werden soll. Standard: True.

Rückgabe:

self

eval() Module[Quelle]#

Setzt das Modul in den Auswertungsmodus.

Rückgabe:

self

lazy_reset_grad()[Quelle]#

Setzt die Trainingsgradienten verzögert zurück.

Diese Funktion setzt den internen Status des Moduls so, dass die Modulgradienten unmittelbar vor der Berechnung neuer Gradienten bei der nächsten Ausführung von train() zurückgesetzt werden.

get_contiguous_parameters(trainable_only: bool = False) OrtValue[Quelle]#

Erstellt einen zusammenhängenden Puffer der Trainingssitzungsparameter

Parameter:

trainable_only – Wenn True, werden nur trainierbare Parameter berücksichtigt. Andernfalls werden alle Parameter berücksichtigt.

Rückgabe:

Der zusammenhängende Puffer der Trainingssitzungsparameter.

get_parameters_size(trainable_only: bool = True) int[Quelle]#

Gibt die Größe der Parameter zurück.

Parameter:

trainable_only – Wenn True, werden nur trainierbare Parameter berücksichtigt. Andernfalls werden alle Parameter berücksichtigt.

Rückgabe:

Die Anzahl der primitiven (z. B. Fließkomma-) Elemente in den Parametern.

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#

Kopiert den OrtValue-Puffer in die Trainingssitzungsparameter.

Falls das Modul aus einem nominalen Checkpoint geladen wurde, muss diese Funktion aufgerufen werden, um die aktualisierten Parameter auf den Checkpoint zu laden, um ihn zu vervollständigen.

Parameter:

buffer – Der OrtValue-Puffer, der in die Trainingssitzungsparameter kopiert werden soll.

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) None[source]#

Exportiert das Modell für die Inferenz.

Sobald das Training abgeschlossen ist, kann diese Funktion verwendet werden, um die trainingsspezifischen Knoten im ONNX-Modell zu entfernen. Insbesondere führt diese Funktion Folgendes aus:

  • Durchläuft den Trainingsgraphen und identifiziert Knoten, die die gegebenen Ausgabenamen erzeugen.

  • Entfernt alle nachfolgenden Knoten im Graphen, da sie für den Inferenzgraphen nicht relevant sind.

Parameter:
  • inference_model_uri – Der Pfad zum Inferenzmodell.

  • graph_output_names – Die Liste der Ausgabenamen, die für die Inferenz erforderlich sind.

input_names() list[str][source]#

Gibt die Eingabenamen des Trainings- oder Auswertungsmodells zurück.

output_names() list[str][source]#

Gibt die Ausgabenamen des Trainings- oder Auswertungsmodells zurück.

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#

Basiert auf: object

Klasse, die Methoden zur Aktualisierung der Modellparameter basierend auf den berechneten Gradienten bereitstellt.

Parameter:
  • optimizer_uri – Der Pfad zum Optimierermodell.

  • module – Das zu trainierende Modul.

step() None[source]#

Aktualisiert die Modellparameter basierend auf den berechneten Gradienten.

Diese Methode aktualisiert die Modellparameter, indem sie einen Schritt in Richtung der berechneten Gradienten macht. Der verwendete Optimierer hängt vom bereitgestellten Optimierermodell ab.

set_learning_rate(learning_rate: float) None[source]#

Setzt die Lernrate für den Optimierer.

Parameter:

learning_rate – Die zu setzende Lernrate.

get_learning_rate() float[source]#

Ruft die aktuelle Lernrate des Optimierers ab.

Rückgabe:

Die aktuelle Lernrate.

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#

Basiert auf: object

Aktualisiert linear die Lernrate im Optimierer

Der lineare Lernraten-Scheduler verringert die Lernrate durch einen linear aktualisierten multiplikativen Faktor von der anfänglichen Lernrate, die in der Trainingssitzung festgelegt wurde, auf 0. Die Verringerung erfolgt nach der anfänglichen Aufwärmphase, in der die Lernrate linear von 0 auf die angegebene anfängliche Lernrate erhöht wird.

Parameter:
  • optimizer – Der onnxruntime-Trainingsoptimierer des Benutzers

  • warmup_step_count – Die Anzahl der Schritte in der Aufwärmphase.

  • total_step_count – Die Gesamtzahl der Trainingsschritte.

  • initial_lr – Die anfängliche Lernrate.

step() None[source]#

Aktualisiert die Lernrate des Optimierers linear.

Diese Methode sollte bei jedem Trainingsschritt aufgerufen werden, um sicherzustellen, dass die Lernrate korrekt angepasst wird.