PyTorch-Modell mit benutzerdefinierten ONNX-Operatoren exportieren

Dieses Dokument erklärt den Prozess des Exports von PyTorch-Modellen mit benutzerdefinierten ONNX Runtime-Ops. Ziel ist es, ein PyTorch-Modell mit Operatoren zu exportieren, die in ONNX nicht unterstützt werden, und ONNX Runtime zu erweitern, um diese benutzerdefinierten Ops zu unterstützen.

Inhalt

Eingebaute Contrib-Ops exportieren

"Contrib Ops" bezieht sich auf die Menge benutzerdefinierter Ops, die in den meisten ORT-Paketen integriert sind. Symbolische Funktionen für alle Contrib-Ops sollten in pytorch_export_contrib_ops.py definiert sein.

Um diese Contrib-Ops zu exportieren, rufen Sie pytorch_export_contrib_ops.register() auf, bevor Sie torch.onnx.export() aufrufen. Zum Beispiel

from onnxruntime.tools import pytorch_export_contrib_ops
import torch

pytorch_export_contrib_ops.register()
torch.onnx.export(...)

Einen benutzerdefinierten Operator exportieren

Um einen benutzerdefinierten Operator zu exportieren, der kein Contrib-Operator ist oder der nicht bereits in pytorch_export_contrib_ops enthalten ist, muss eine benutzerdefinierte Operator-Symbolfunktion geschrieben und registriert werden.

Wir nehmen den Invers-Operator als Beispiel

from torch.onnx import register_custom_op_symbolic

def my_inverse(g, self):
    return g.op("com.microsoft::Inverse", self)

# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)

<namespace> ist ein Teil des Namens des Torch-Operators. Für Standard-Torch-Operatoren kann der Namespace weggelassen werden.

com.microsoft sollte als benutzerdefinierte Opset-Domäne für ONNX Runtime-Ops verwendet werden. Sie können die benutzerdefinierte Opset-Version während der Op-Registrierung wählen.

Weitere Informationen zum Schreiben einer symbolischen Funktion finden Sie in der torch.onnx-Dokumentation.

ONNX Runtime mit benutzerdefinierten Operatoren erweitern

Der nächste Schritt ist das Hinzufügen eines Op-Schemas und einer Kernel-Implementierung in ONNX Runtime. Details finden Sie unter Benutzerdefinierte Operatoren.

End-to-End-Test: Exportieren und Ausführen

Sobald der benutzerdefinierte Operator im Exporter registriert und in ONNX Runtime implementiert ist, sollten Sie ihn exportieren und mit ONNX Runtime ausführen können.

Unten finden Sie ein Beispielskript zum Exportieren und Ausführen des Invers-Operators als Teil eines Modells.

Das exportierte Modell enthält eine Kombination aus standardmäßigen ONNX-Ops und den benutzerdefinierten Ops.

Dieser Test vergleicht auch die Ausgabe des PyTorch-Modells mit den Ausgaben von ONNX Runtime, um sowohl den Operator-Export als auch die Implementierung zu testen.

import io
import numpy
import onnxruntime
import torch


class CustomInverse(torch.nn.Module):
    def forward(self, x):
        return torch.inverse(x) + x

x = torch.randn(3, 3)

# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)

model = CustomInverse()
pt_outputs = model(x)

# Run the exported model with ONNX Runtime
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)

# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)

Standardmäßig wird die Opset-Version für benutzerdefinierte Opset für 1 eingestellt. Wenn Sie Ihren benutzerdefinierten Operator in eine höhere Opset-Version exportieren möchten, können Sie die benutzerdefinierte Opset-Domäne und -Version über das Argument custom_opsets beim Aufruf der Export-API angeben. Beachten Sie, dass dies anders ist als die Opset-Version, die mit der Standard- ONNX -Domäne verbunden ist.

torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})

Beachten Sie, dass Sie einen benutzerdefinierten Operator in jede Version exportieren können, die >= der bei der Registrierung verwendeten Opset-Version ist.