Übersicht#
onnxruntime-training’s ORTModule bietet eine Hochleistungs-Trainings-Engine für Modelle, die mit dem PyTorch-Frontend definiert sind. ORTModule wurde entwickelt, um das Training großer Modelle zu beschleunigen, ohne die Modelldefinition oder den Trainingscode ändern zu müssen.
Ziel von ORTModule ist es, einen Drop-in-Ersatz für eines oder mehrere torch.nn.Module-Objekte in einem PyTorch-Programm des Benutzers bereitzustellen und die Vorwärts- und Rückwärtsdurchläufe dieser Module mithilfe von ORT auszuführen.
Infolgedessen kann der Benutzer sein Trainingsskript mit ORT beschleunigen, ohne seine Trainingsschleife ändern zu müssen.
Benutzer können Standard-PyTorch-Debugging-Techniken für Konvergenzprobleme verwenden, z. B. durch Überprüfung der berechneten Gradienten für die Parameter des Modells.
Das folgende Codebeispiel veranschaulicht, wie ORTModule in einem Trainingsskript eines Benutzers verwendet würde, im einfachen Fall, in dem das gesamte Modell an ONNX Runtime ausgelagert werden kann
from onnxruntime.training import ORTModule
# Original PyTorch model
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
...
def forward(self, x):
...
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model) # The only change to the original PyTorch script
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Training Loop is unchanged
for data, target in data_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()