API#

class onnxruntime.training.ORTModule(module: Module, debug_options: Optional[DebugOptions] = None)[Quelle]#

Bases: Module

Erweitert das torch.nn.Module-Modell des Benutzers, um die extrem schnelle Trainings-Engine von ONNX Runtime zu nutzen.

ORTModule spezialisiert das torch.nn.Module-Modell des Benutzers und stellt forward(), backward() sowie alle anderen APIs von torch.nn.Module bereit.

Parameter:
  • module (torch.nn.Module) – PyTorch-Modul des Benutzers, das ORTModule spezialisiert

  • debug_options (DebugOptions, optional) – Debugging-Optionen für ORTModule.

Initialisiert den internen Modulzustand, der sowohl von nn.Module als auch von ScriptModule gemeinsam genutzt wird.

forward(*inputs, **kwargs)[Quelle]#

Delegiert den forward()-Pass des PyTorch-Trainings an ONNX Runtime.

Der erste Aufruf von forward führt Einrichtungs- und Prüfschritte durch. Während dieses Aufrufs ermittelt ORTModule, ob das Modul mit ONNX Runtime trainiert werden kann. Aus diesem Grund dauert die Ausführung des ersten forward-Aufrufs länger als bei nachfolgenden Aufrufen. Die Ausführung wird unterbrochen, wenn ONNX Runtime das Modell nicht für das Training verarbeiten kann.

Parameter:
  • inputs – Positionsargumente, variable positionsabhängige Eingaben für die forward-Methode des PyTorch-Moduls.

  • kwargs – Schlüsselwort- und variable Schlüsselwortargumente für die forward-Methode des PyTorch-Moduls.

Rückgabe:

Die Ausgabe, wie sie von der vom PyTorch-Modul des Benutzers definierten forward-Methode erwartet wird. Unterstützte Ausgabewerte sind Tensoren, verschachtelte Sequenzen von Tensoren und verschachtelte Dictionaries von Tensorwerten.

add_module(name: str, module: Optional[Module]) None[Quelle]#

Löst eine ORTModuleTorchModelException-Ausnahme aus, da ORTModule das Hinzufügen von Modulen zu sich selbst nicht unterstützt.

property module#

Das ursprüngliche torch.nn.Module, das dieses Modul umschließt.

Diese Eigenschaft bietet Zugriff auf Methoden und Eigenschaften des ursprünglichen Moduls.

apply(fn: Callable[[Module], None]) T[Quelle]#

Überschreibt apply() zur Delegation der Ausführung an ONNX Runtime.

train(mode: bool = True) T[Quelle]#

Überschreibt train() zur Delegation der Ausführung an ONNX Runtime.

state_dict(destination=None, prefix='', keep_vars=False)[Quelle]#

Überschreibt state_dict() zur Delegation der Ausführung an ONNX Runtime.

load_state_dict(state_dict: OrderedDict[str, Tensor], strict: bool = True)[Quelle]#

Überschreibt load_state_dict() zur Delegation der Ausführung an ONNX Runtime.

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None[Quelle]#

Überschreibt register_buffer()

register_parameter(name: str, param: Optional[Parameter]) None[Quelle]#

Überschreibt register_parameter()

get_parameter(target: str) Parameter[Quelle]#

Überschreibt get_parameter()

get_buffer(target: str) Tensor[Quelle]#

Überschreibt get_buffer()

parameters(recurse: bool = True) Iterator[Parameter][Quelle]#

Überschreibt parameters()

named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]][Quelle]#

Überschreibt named_parameters()

buffers(recurse: bool = True) Iterator[Tensor][Quelle]#

Überschreibt buffers()

named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Tensor]][Quelle]#

Überschreibt named_buffers()

named_children() Iterator[Tuple[str, Module]][Quelle]#

Überschreibt named_children()

modules() Iterator[Module][Quelle]#

Überschreibt modules()

named_modules(*args, **kwargs)[Quelle]#

Überschreibt named_modules()

bfloat16() T#

Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp bfloat16.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Rückgabe:

self

Rückgabetyp:

Module

children() Iterator[Module]#

Gibt einen Iterator über die direkten Kindermodule zurück.

Yields:

Module – ein Kindermodul

compile(*args, **kwargs)#

Kompiliert die forward-Methode dieses Moduls mithilfe von torch.compile().

Die __call__-Methode dieses Moduls wird kompiliert und alle Argumente werden unverändert an torch.compile() übergeben.

Weitere Informationen zu den Argumenten dieser Funktion finden Sie unter torch.compile().

cpu() T#

Verschiebt alle Modellparameter und -puffer zur CPU.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Rückgabe:

self

Rückgabetyp:

Module

cuda(device: Optional[Union[int, device]] = None) T#

Verschiebt alle Modellparameter und -puffer zur GPU.

Dies bewirkt auch, dass zugehörige Parameter und Puffer zu unterschiedlichen Objekten werden. Daher sollte es vor der Konstruktion des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf der GPU verbleibt.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:

device (int, optional) – Wenn angegeben, werden alle Parameter auf dieses Gerät kopiert.

Rückgabe:

self

Rückgabetyp:

Module

double() T#

Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp double.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Rückgabe:

self

Rückgabetyp:

Module

eval() T#

Setzt das Modul in den Auswertungsmodus.

Dies hat nur Auswirkungen auf bestimmte Module. In der Dokumentation der einzelnen Module finden Sie Einzelheiten zu deren Verhalten im Trainings-/Auswertungsmodus, d.h. ob sie betroffen sind, z.B. Dropout, BatchNorm usw.

Dies ist äquivalent zu self.train(False).

Siehe Lokales Deaktivieren der Gradientenberechnung für einen Vergleich zwischen .eval() und mehreren ähnlichen Mechanismen, die damit verwechselt werden könnten.

Rückgabe:

self

Rückgabetyp:

Module

extra_repr() str#

Gibt die zusätzliche Darstellung des Moduls zurück.

Um benutzerdefinierte zusätzliche Informationen auszugeben, sollten Sie diese Methode in Ihren eigenen Modulen neu implementieren. Sowohl einzeilige als auch mehrzeilige Zeichenfolgen sind zulässig.

float() T#

Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp float.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Rückgabe:

self

Rückgabetyp:

Module

get_extra_state() Any#

Gibt zusätzliche Zustände zurück, die in die state_dict des Moduls aufgenommen werden sollen.

Implementieren Sie diese und eine entsprechende set_extra_state() für Ihr Modul, wenn Sie zusätzliche Zustände speichern müssen. Diese Funktion wird beim Aufbau der state_dict() des Moduls aufgerufen.

Beachten Sie, dass zusätzliche Zustände pickelbar sein müssen, um eine funktionierende Serialisierung der state_dict zu gewährleisten. Wir bieten nur Abwärtskompatibilitätsgarantien für die Serialisierung von Tensoren; andere Objekte können die Abwärtskompatibilität brechen, wenn sich ihre serialisierte gepickelte Form ändert.

Rückgabe:

Zusätzlicher Zustand, der in der state_dict des Moduls gespeichert werden soll.

Rückgabetyp:

Objekt

get_submodule(target: str) Module[Quelle]#

Gibt das Untermodul zurück, das durch target angegeben ist, wenn es existiert, andernfalls wird ein Fehler ausgelöst.

Beispielsweise nehmen wir an, Sie haben ein nn.Module A, das wie folgt aussieht:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(Das Diagramm zeigt ein nn.Module A. A hat ein verschachteltes Untermodul net_b, das selbst zwei Untermodule net_c und linear hat. net_c hat dann ein Untermodul conv.)

Um zu überprüfen, ob wir das linear-Untermodul haben, würden wir get_submodule("net_b.linear") aufrufen. Um zu überprüfen, ob wir das conv-Untermodul haben, würden wir get_submodule("net_b.net_c.conv") aufrufen.

Die Laufzeit von get_submodule ist durch den Verschachtelungsgrad der Module in target begrenzt. Eine Abfrage gegen named_modules erzielt das gleiche Ergebnis, ist aber O(N) in der Anzahl der transitiven Module. Daher sollte für eine einfache Prüfung, ob ein Untermodul existiert, immer get_submodule verwendet werden.

Parameter:

target – Der vollständig qualifizierte String-Name des zu suchenden Untermoduls. (Siehe obiges Beispiel für die Angabe eines vollständig qualifizierten Strings.)

Rückgabe:

Das Untermodul, auf das sich target bezieht.

Rückgabetyp:

torch.nn.Module

Ausnahmen:

AttributeError – Wenn zu irgendeinem Zeitpunkt entlang des Pfades, der sich aus dem Zielstring ergibt, der (Unter-)Pfad zu einem nicht vorhandenen Attributnamen oder einem Objekt, das keine Instanz von nn.Module ist, aufgelöst wird.

half() T#

Wandelt alle Gleitkomma-Parameter und Puffer in den Datentyp half um.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Rückgabe:

self

Rückgabetyp:

Module

ipu(device: Optional[Union[int, device]] = None) T#

Verschiebt alle Modellparameter und Puffer auf das IPU.

Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Erstellung des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf dem IPU liegen soll.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:

device (int, optional) – Wenn angegeben, werden alle Parameter auf dieses Gerät kopiert.

Rückgabe:

self

Rückgabetyp:

Module

mtia(device: Optional[Union[int, device]] = None) T#

Verschiebt alle Modellparameter und Puffer auf das MTIA.

Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Erstellung des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf dem MTIA liegen soll.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:

device (int, optional) – Wenn angegeben, werden alle Parameter auf dieses Gerät kopiert.

Rückgabe:

self

Rückgabetyp:

Module

register_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]]) RemovableHandle#

Registriert einen Backward-Hook am Modul.

Diese Funktion ist veraltet zugunsten von register_full_backward_hook() und das Verhalten dieser Funktion wird sich in zukünftigen Versionen ändern.

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_forward_hook(hook: Union[Callable[[T, tuple[Any, ...], Any], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle#

Registriert einen Forward-Hook am Modul.

Der Hook wird jedes Mal aufgerufen, nachdem forward() eine Ausgabe berechnet hat.

Wenn with_kwargs False ist oder nicht angegeben wird, enthält die Eingabe nur die positionsbezogenen Argumente, die dem Modul übergeben wurden. Schlüsselwortargumente werden nicht an die Hooks weitergegeben, sondern nur an forward. Der Hook kann die Ausgabe modifizieren. Er kann die Eingabe inplace modifizieren, hat aber keine Auswirkung auf den Forward-Pass, da er nach dem Aufruf von forward() aufgerufen wird. Der Hook sollte die folgende Signatur haben:

hook(module, args, output) -> None or modified output

Wenn with_kwargs True ist, werden dem Forward-Hook die an die Forward-Funktion übergebenen kwargs übergeben und es wird erwartet, dass er die möglicherweise modifizierte Ausgabe zurückgibt. Der Hook sollte die folgende Signatur haben:

hook(module, args, kwargs, output) -> None or modified output
Parameter:
  • hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.

  • prepend (bool) – Wenn True, wird der bereitgestellte hook vor allen vorhandenen forward-Hooks auf diesem torch.nn.Module aufgerufen. Andernfalls wird der bereitgestellte hook nach allen vorhandenen forward-Hooks auf diesem torch.nn.Module aufgerufen. Beachten Sie, dass globale forward-Hooks, die mit register_module_forward_hook() registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden. Standard: False

  • with_kwargs (bool) – Wenn True, wird dem hook die an die Forward-Funktion übergebenen kwargs übergeben. Standard: False

  • always_call (bool) – Wenn True, wird der hook unabhängig davon ausgeführt, ob beim Aufruf des Moduls eine Ausnahme ausgelöst wird. Standard: False

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook: Union[Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle#

Registriert einen Forward Pre-Hook am Modul.

Der Hook wird jedes Mal aufgerufen, bevor forward() aufgerufen wird.

Wenn with_kwargs falsch ist oder nicht angegeben wird, enthält die Eingabe nur die positionsbezogenen Argumente, die dem Modul übergeben wurden. Schlüsselwortargumente werden nicht an die Hooks weitergegeben, sondern nur an forward. Der Hook kann die Eingabe modifizieren. Der Benutzer kann entweder ein Tupel oder einen einzelnen modifizierten Wert im Hook zurückgeben. Wir wickeln den Wert in ein Tupel ein, wenn ein einzelner Wert zurückgegeben wird (es sei denn, dieser Wert ist bereits ein Tupel). Der Hook sollte die folgende Signatur haben:

hook(module, args) -> None or modified input

Wenn with_kwargs wahr ist, werden dem Forward Pre-Hook die an die Forward-Funktion übergebenen kwargs übergeben. Und wenn der Hook die Eingabe modifiziert, sollten sowohl die args als auch die kwargs zurückgegeben werden. Der Hook sollte die folgende Signatur haben:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Parameter:
  • hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.

  • prepend (bool) – Wenn wahr, wird der bereitgestellte hook vor allen vorhandenen forward_pre-Hooks auf diesem torch.nn.Module aufgerufen. Andernfalls wird der bereitgestellte hook nach allen vorhandenen forward_pre-Hooks auf diesem torch.nn.Module aufgerufen. Beachten Sie, dass globale forward_pre-Hooks, die mit register_module_forward_pre_hook() registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden. Standard: False

  • with_kwargs (bool) – Wenn wahr, wird dem hook die an die Forward-Funktion übergebenen kwargs übergeben. Standard: False

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle#

Registriert einen Backward-Hook am Modul.

Der Hook wird jedes Mal aufgerufen, wenn die Gradienten in Bezug auf ein Modul berechnet werden, d.h. der Hook wird ausgeführt, wenn und nur wenn die Gradienten in Bezug auf die Modulausgaben berechnet werden. Der Hook sollte die folgende Signatur haben:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

Die grad_input und grad_output sind Tupel, die die Gradienten in Bezug auf die Eingaben bzw. Ausgaben enthalten. Der Hook sollte seine Argumente nicht modifizieren, kann aber optional einen neuen Gradienten in Bezug auf die Eingabe zurückgeben, der in nachfolgenden Berechnungen anstelle von grad_input verwendet wird. grad_input entspricht nur den als positionsbezogene Argumente übergebenen Eingaben, und alle Kwargs-Argumente werden ignoriert. Einträge in grad_input und grad_output sind None für alle Nicht-Tensor-Argumente.

Aus technischen Gründen erhält die Forward-Funktion dieses Hooks, wenn er auf ein Modul angewendet wird, eine Ansicht jedes Tensors, der an das Modul übergeben wird. Ebenso erhält der Aufrufer eine Ansicht jedes Tensors, der von der Forward-Funktion des Moduls zurückgegeben wird.

Warnung

Die Modifizierung von Eingaben oder Ausgaben inplace ist bei der Verwendung von Backward-Hooks nicht erlaubt und führt zu einem Fehler.

Parameter:
  • hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.

  • prepend (bool) – Wenn wahr, wird der bereitgestellte hook vor allen vorhandenen backward-Hooks auf diesem torch.nn.Module aufgerufen. Andernfalls wird der bereitgestellte hook nach allen vorhandenen backward-Hooks auf diesem torch.nn.Module aufgerufen. Beachten Sie, dass globale backward-Hooks, die mit register_module_full_backward_hook() registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden.

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle#

Registriert einen Backward-Pre-Hook für das Modul.

Der Hook wird jedes Mal aufgerufen, wenn die Gradienten für das Modul berechnet werden. Der Hook sollte die folgende Signatur haben

hook(module, grad_output) -> tuple[Tensor] or None

Der grad_output ist ein Tupel. Der Hook sollte seine Argumente nicht ändern, kann aber optional ein neues Gradientenobjekt bezüglich der Ausgabe zurückgeben, das anstelle von grad_output in nachfolgenden Berechnungen verwendet wird. Einträge in grad_output sind None für alle Nicht-Tensor-Argumente.

Aus technischen Gründen erhält die Forward-Funktion dieses Hooks, wenn er auf ein Modul angewendet wird, eine Ansicht jedes Tensors, der an das Modul übergeben wird. Ebenso erhält der Aufrufer eine Ansicht jedes Tensors, der von der Forward-Funktion des Moduls zurückgegeben wird.

Warnung

Das Ändern von Eingaben vor Ort ist bei der Verwendung von Backward-Hooks nicht erlaubt und führt zu einem Fehler.

Parameter:
  • hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.

  • prepend (bool) – Wenn True, wird der bereitgestellte hook vor allen vorhandenen backward_pre Hooks für dieses torch.nn.Module ausgelöst. Andernfalls wird der bereitgestellte hook nach allen vorhandenen backward_pre Hooks für dieses torch.nn.Module ausgelöst. Beachten Sie, dass globale backward_pre Hooks, die mit register_module_full_backward_pre_hook() registriert wurden, vor allen Hooks ausgelöst werden, die mit dieser Methode registriert wurden.

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)#

Registriert einen Post-Hook, der nach dem Aufruf von load_state_dict() des Moduls ausgeführt wird.

Er sollte die folgende Signatur haben:

hook(module, incompatible_keys) -> None

Das Argument module ist das aktuelle Modul, auf dem dieser Hook registriert ist, und das Argument incompatible_keys ist ein NamedTuple, das die Attribute missing_keys und unexpected_keys enthält. missing_keys ist eine list von str mit den fehlenden Schlüsseln und unexpected_keys ist eine list von str mit den unerwarteten Schlüsseln.

Die gegebenen incompatible_keys können bei Bedarf vor Ort geändert werden.

Beachten Sie, dass die Prüfungen, die beim Aufruf von load_state_dict() mit strict=True durchgeführt werden, von den Änderungen beeinflusst werden, die der Hook an missing_keys oder unexpected_keys vornimmt, wie erwartet. Hinzufügungen zu beiden Schlüsselmengen führen zu einem Fehler, wenn strict=True ist, und das Leeren beider fehlenden und unerwarteten Schlüssel vermeidet einen Fehler.

Rückgabe:

Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von handle.remove() zu entfernen.

Rückgabetyp:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook)#

Registriert einen Pre-Hook, der vor dem Aufruf von load_state_dict() des Moduls ausgeführt wird.

Er sollte die folgende Signatur haben:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950

Parameter:

hook (Callable) – Aufrufbare Funktion, die vor dem Laden des state_dict aufgerufen wird.

register_module(name: str, module: Optional[Module]) None#

Alias für add_module().

register_state_dict_post_hook(hook)#

Registriert einen Post-Hook für die Methode state_dict().

Er sollte die folgende Signatur haben:

hook(module, state_dict, prefix, local_metadata) -> None

Die registrierten Hooks können das state_dict vor Ort ändern.

register_state_dict_pre_hook(hook)#

Registriert einen Pre-Hook für die Methode state_dict().

Er sollte die folgende Signatur haben:

hook(module, prefix, keep_vars) -> None

Die registrierten Hooks können verwendet werden, um eine Vorverarbeitung durchzuführen, bevor der state_dict Aufruf erfolgt.

requires_grad_(requires_grad: bool = True) T#

Ändert, ob Autograd Operationen auf Parametern in diesem Modul aufzeichnen soll.

Diese Methode setzt die requires_grad Attribute der Parameter vor Ort.

Diese Methode ist hilfreich zum Einfrieren eines Teils des Moduls für Finetuning oder zum individuellen Trainieren von Teilen eines Modells (z. B. GAN-Training).

Siehe Lokales Deaktivieren der Gradientenberechnung für einen Vergleich zwischen .requires_grad_() und mehreren ähnlichen Mechanismen, die mit ihm verwechselt werden könnten.

Parameter:

requires_grad (bool) – ob Autograd Operationen auf Parametern in diesem Modul aufzeichnen soll. Standard: True.

Rückgabe:

self

Rückgabetyp:

Module

set_extra_state(state: Any) None#

Setzt den zusätzlichen Zustand, der im geladenen state_dict enthalten ist.

Diese Funktion wird von load_state_dict() aufgerufen, um zusätzlichen Zustand im state_dict zu verarbeiten. Implementieren Sie diese Funktion und eine entsprechende get_extra_state() für Ihr Modul, wenn Sie zusätzlichen Zustand in seinem state_dict speichern müssen.

Parameter:

state (dict) – Zusätzlicher Zustand aus dem state_dict

set_submodule(target: str, module: Module, strict: bool = False) None#

Setzt das durch target angegebene Untermodul, falls es existiert, andernfalls wird ein Fehler ausgelöst.

Hinweis

Wenn strict auf False (Standard) gesetzt ist, ersetzt die Methode ein vorhandenes Untermodul oder erstellt ein neues Untermodul, wenn das übergeordnete Modul existiert. Wenn strict auf True gesetzt ist, versucht die Methode nur, ein vorhandenes Untermodul zu ersetzen und löst einen Fehler aus, wenn das Untermodul nicht existiert.

Beispielsweise nehmen wir an, Sie haben ein nn.Module A, das wie folgt aussieht:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(Das Diagramm zeigt ein nn.Module A. A hat ein verschachteltes Untermodul net_b, das selbst zwei Untermodule net_c und linear hat. net_c hat dann ein Untermodul conv.)

Um Conv2d durch ein neues Untermodul Linear zu überschreiben, könnten Sie set_submodule("net_b.net_c.conv", nn.Linear(1, 1)) aufrufen, wobei strict True oder False sein kann.

Um ein neues Untermodul Conv2d zum vorhandenen net_b Modul hinzuzufügen, rufen Sie set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)) auf.

Im obigen Beispiel, wenn Sie strict=True setzen und set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True) aufrufen, wird ein AttributeError ausgelöst, da net_b kein Untermodul namens conv hat.

Parameter:
  • target – Der vollständig qualifizierte String-Name des zu suchenden Untermoduls. (Siehe obiges Beispiel für die Angabe eines vollständig qualifizierten Strings.)

  • module – Das Modul, auf das das Untermodul gesetzt werden soll.

  • strict – Wenn False, ersetzt die Methode ein vorhandenes Untermodul oder erstellt ein neues Untermodul, wenn das übergeordnete Modul existiert. Wenn True, versucht die Methode nur, ein vorhandenes Untermodul zu ersetzen und löst einen Fehler aus, wenn das Untermodul noch nicht existiert.

Ausnahmen:
  • ValueError – Wenn der target String leer ist oder wenn module keine Instanz von nn.Module ist.

  • AttributeError – Wenn zu irgendeinem Zeitpunkt entlang des Pfades, der sich aus dem target String ergibt, der (Unter-)Pfad auf einen nicht existierenden Attributnamen oder ein Objekt, das keine Instanz von nn.Module ist, aufgelöst wird.

share_memory() T#

Siehe torch.Tensor.share_memory_().

to(*args, **kwargs)#

Verschieben und/oder umwandeln der Parameter und Puffer.

Dies kann aufgerufen werden als

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Die Signatur ähnelt torch.Tensor.to(), akzeptiert aber nur Floating-Point- oder komplexe dtypes. Zusätzlich werden mit dieser Methode nur Floating-Point- oder komplexe Parameter und Puffer in dtype (falls angegeben) umgewandelt. Die ganzzahligen Parameter und Puffer werden nach device verschoben, falls dieses angegeben ist, jedoch mit unveränderten dtypes. Wenn non_blocking gesetzt ist, versucht es, asynchron zum Host zu konvertieren/verschieben (falls möglich), z. B. CPU-Tensoren mit angepinnter Speicherzuweisung auf CUDA-Geräte zu verschieben.

Siehe unten für Beispiele.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:
  • device (torch.device) – Das gewünschte Gerät der Parameter und Puffer in diesem Modul.

  • dtype (torch.dtype) – Der gewünschte Floating-Point- oder komplexe dtype der Parameter und Puffer in diesem Modul.

  • tensor (torch.Tensor) – Tensor, dessen dtype und Gerät die gewünschten dtype und Gerät für alle Parameter und Puffer in diesem Modul sind.

  • memory_format (torch.memory_format) – Das gewünschte Speicherformat für 4D-Parameter und Puffer in diesem Modul (nur Keyword-Argument).

Rückgabe:

self

Rückgabetyp:

Module

Beispiele

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) T#

Verschiebt die Parameter und Puffer auf das angegebene Gerät, ohne den Speicher zu kopieren.

Parameter:
  • device (torch.device) – Das gewünschte Gerät der Parameter und Puffer in diesem Modul.

  • recurse (bool) – Ob Parameter und Puffer von Untermodulen rekursiv auf das angegebene Gerät verschoben werden sollen.

Rückgabe:

self

Rückgabetyp:

Module

type(dst_type: Union[dtype, str]) T#

Wandelt alle Parameter und Puffer in dst_type um.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:

dst_type (type oder string) – Der gewünschte Typ.

Rückgabe:

self

Rückgabetyp:

Module

xpu(device: Optional[Union[int, device]] = None) T#

Verschiebt alle Modellparameter und Puffer zu XPU.

Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Konstruktion des Optimierers aufgerufen werden, wenn das Modul optimiert werden soll.

Hinweis

Diese Methode modifiziert das Modul direkt (in-place).

Parameter:

device (int, optional) – Wenn angegeben, werden alle Parameter auf dieses Gerät kopiert.

Rückgabe:

self

Rückgabetyp:

Module

zero_grad(set_to_none: bool = True) None#

Setzt die Gradienten aller Modellparameter zurück.

Siehe ähnliche Funktion unter torch.optim.Optimizer für mehr Kontext.

Parameter:

set_to_none (bool) – Anstatt auf Null zu setzen, werden die Grads auf None gesetzt. Siehe torch.optim.Optimizer.zero_grad() für Details.