API#
- class onnxruntime.training.ORTModule(module: Module, debug_options: Optional[DebugOptions] = None)[Quelle]#
Bases:
ModuleErweitert das
torch.nn.Module-Modell des Benutzers, um die extrem schnelle Trainings-Engine von ONNX Runtime zu nutzen.ORTModule spezialisiert das
torch.nn.Module-Modell des Benutzers und stelltforward(),backward()sowie alle anderen APIs vontorch.nn.Modulebereit.- Parameter:
module (torch.nn.Module) – PyTorch-Modul des Benutzers, das ORTModule spezialisiert
debug_options (
DebugOptions, optional) – Debugging-Optionen für ORTModule.
Initialisiert den internen Modulzustand, der sowohl von nn.Module als auch von ScriptModule gemeinsam genutzt wird.
- forward(*inputs, **kwargs)[Quelle]#
Delegiert den
forward()-Pass des PyTorch-Trainings an ONNX Runtime.Der erste Aufruf von forward führt Einrichtungs- und Prüfschritte durch. Während dieses Aufrufs ermittelt ORTModule, ob das Modul mit ONNX Runtime trainiert werden kann. Aus diesem Grund dauert die Ausführung des ersten forward-Aufrufs länger als bei nachfolgenden Aufrufen. Die Ausführung wird unterbrochen, wenn ONNX Runtime das Modell nicht für das Training verarbeiten kann.
- Parameter:
inputs – Positionsargumente, variable positionsabhängige Eingaben für die forward-Methode des PyTorch-Moduls.
kwargs – Schlüsselwort- und variable Schlüsselwortargumente für die forward-Methode des PyTorch-Moduls.
- Rückgabe:
Die Ausgabe, wie sie von der vom PyTorch-Modul des Benutzers definierten forward-Methode erwartet wird. Unterstützte Ausgabewerte sind Tensoren, verschachtelte Sequenzen von Tensoren und verschachtelte Dictionaries von Tensorwerten.
- add_module(name: str, module: Optional[Module]) None[Quelle]#
Löst eine ORTModuleTorchModelException-Ausnahme aus, da ORTModule das Hinzufügen von Modulen zu sich selbst nicht unterstützt.
- property module#
Das ursprüngliche torch.nn.Module, das dieses Modul umschließt.
Diese Eigenschaft bietet Zugriff auf Methoden und Eigenschaften des ursprünglichen Moduls.
- apply(fn: Callable[[Module], None]) T[Quelle]#
Überschreibt
apply()zur Delegation der Ausführung an ONNX Runtime.
- train(mode: bool = True) T[Quelle]#
Überschreibt
train()zur Delegation der Ausführung an ONNX Runtime.
- state_dict(destination=None, prefix='', keep_vars=False)[Quelle]#
Überschreibt
state_dict()zur Delegation der Ausführung an ONNX Runtime.
- load_state_dict(state_dict: OrderedDict[str, Tensor], strict: bool = True)[Quelle]#
Überschreibt
load_state_dict()zur Delegation der Ausführung an ONNX Runtime.
- register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None[Quelle]#
Überschreibt
register_buffer()
- register_parameter(name: str, param: Optional[Parameter]) None[Quelle]#
Überschreibt
register_parameter()
- get_parameter(target: str) Parameter[Quelle]#
Überschreibt
get_parameter()
- get_buffer(target: str) Tensor[Quelle]#
Überschreibt
get_buffer()
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]][Quelle]#
Überschreibt
named_parameters()
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Tensor]][Quelle]#
Überschreibt
named_buffers()
- named_modules(*args, **kwargs)[Quelle]#
Überschreibt
named_modules()
- bfloat16() T#
Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp
bfloat16.Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Rückgabe:
self
- Rückgabetyp:
- children() Iterator[Module]#
Gibt einen Iterator über die direkten Kindermodule zurück.
- Yields:
Module – ein Kindermodul
- compile(*args, **kwargs)#
Kompiliert die forward-Methode dieses Moduls mithilfe von
torch.compile().Die
__call__-Methode dieses Moduls wird kompiliert und alle Argumente werden unverändert antorch.compile()übergeben.Weitere Informationen zu den Argumenten dieser Funktion finden Sie unter
torch.compile().
- cpu() T#
Verschiebt alle Modellparameter und -puffer zur CPU.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Rückgabe:
self
- Rückgabetyp:
- cuda(device: Optional[Union[int, device]] = None) T#
Verschiebt alle Modellparameter und -puffer zur GPU.
Dies bewirkt auch, dass zugehörige Parameter und Puffer zu unterschiedlichen Objekten werden. Daher sollte es vor der Konstruktion des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf der GPU verbleibt.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- double() T#
Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp
double.Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Rückgabe:
self
- Rückgabetyp:
- eval() T#
Setzt das Modul in den Auswertungsmodus.
Dies hat nur Auswirkungen auf bestimmte Module. In der Dokumentation der einzelnen Module finden Sie Einzelheiten zu deren Verhalten im Trainings-/Auswertungsmodus, d.h. ob sie betroffen sind, z.B.
Dropout,BatchNormusw.Dies ist äquivalent zu
self.train(False).Siehe Lokales Deaktivieren der Gradientenberechnung für einen Vergleich zwischen .eval() und mehreren ähnlichen Mechanismen, die damit verwechselt werden könnten.
- Rückgabe:
self
- Rückgabetyp:
- extra_repr() str#
Gibt die zusätzliche Darstellung des Moduls zurück.
Um benutzerdefinierte zusätzliche Informationen auszugeben, sollten Sie diese Methode in Ihren eigenen Modulen neu implementieren. Sowohl einzeilige als auch mehrzeilige Zeichenfolgen sind zulässig.
- float() T#
Konvertiert alle Gleitkommaparameter und -puffer in den Datentyp
float.Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Rückgabe:
self
- Rückgabetyp:
- get_extra_state() Any#
Gibt zusätzliche Zustände zurück, die in die
state_dictdes Moduls aufgenommen werden sollen.Implementieren Sie diese und eine entsprechende
set_extra_state()für Ihr Modul, wenn Sie zusätzliche Zustände speichern müssen. Diese Funktion wird beim Aufbau der state_dict() des Moduls aufgerufen.Beachten Sie, dass zusätzliche Zustände pickelbar sein müssen, um eine funktionierende Serialisierung der
state_dictzu gewährleisten. Wir bieten nur Abwärtskompatibilitätsgarantien für die Serialisierung von Tensoren; andere Objekte können die Abwärtskompatibilität brechen, wenn sich ihre serialisierte gepickelte Form ändert.- Rückgabe:
Zusätzlicher Zustand, der in der
state_dictdes Moduls gespeichert werden soll.- Rückgabetyp:
- get_submodule(target: str) Module[Quelle]#
Gibt das Untermodul zurück, das durch
targetangegeben ist, wenn es existiert, andernfalls wird ein Fehler ausgelöst.Beispielsweise nehmen wir an, Sie haben ein
nn.ModuleA, das wie folgt aussieht:A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) )(Das Diagramm zeigt ein
nn.ModuleA.Ahat ein verschachteltes Untermodulnet_b, das selbst zwei Untermodulenet_cundlinearhat.net_chat dann ein Untermodulconv.)Um zu überprüfen, ob wir das
linear-Untermodul haben, würden wirget_submodule("net_b.linear")aufrufen. Um zu überprüfen, ob wir dasconv-Untermodul haben, würden wirget_submodule("net_b.net_c.conv")aufrufen.Die Laufzeit von
get_submoduleist durch den Verschachtelungsgrad der Module intargetbegrenzt. Eine Abfrage gegennamed_moduleserzielt das gleiche Ergebnis, ist aber O(N) in der Anzahl der transitiven Module. Daher sollte für eine einfache Prüfung, ob ein Untermodul existiert, immerget_submoduleverwendet werden.- Parameter:
target – Der vollständig qualifizierte String-Name des zu suchenden Untermoduls. (Siehe obiges Beispiel für die Angabe eines vollständig qualifizierten Strings.)
- Rückgabe:
Das Untermodul, auf das sich
targetbezieht.- Rückgabetyp:
- Ausnahmen:
AttributeError – Wenn zu irgendeinem Zeitpunkt entlang des Pfades, der sich aus dem Zielstring ergibt, der (Unter-)Pfad zu einem nicht vorhandenen Attributnamen oder einem Objekt, das keine Instanz von
nn.Moduleist, aufgelöst wird.
- half() T#
Wandelt alle Gleitkomma-Parameter und Puffer in den Datentyp
halfum.Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Rückgabe:
self
- Rückgabetyp:
- ipu(device: Optional[Union[int, device]] = None) T#
Verschiebt alle Modellparameter und Puffer auf das IPU.
Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Erstellung des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf dem IPU liegen soll.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- mtia(device: Optional[Union[int, device]] = None) T#
Verschiebt alle Modellparameter und Puffer auf das MTIA.
Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Erstellung des Optimierers aufgerufen werden, wenn das Modul während der Optimierung auf dem MTIA liegen soll.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- register_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]]) RemovableHandle#
Registriert einen Backward-Hook am Modul.
Diese Funktion ist veraltet zugunsten von
register_full_backward_hook()und das Verhalten dieser Funktion wird sich in zukünftigen Versionen ändern.- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_forward_hook(hook: Union[Callable[[T, tuple[Any, ...], Any], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle#
Registriert einen Forward-Hook am Modul.
Der Hook wird jedes Mal aufgerufen, nachdem
forward()eine Ausgabe berechnet hat.Wenn
with_kwargsFalseist oder nicht angegeben wird, enthält die Eingabe nur die positionsbezogenen Argumente, die dem Modul übergeben wurden. Schlüsselwortargumente werden nicht an die Hooks weitergegeben, sondern nur anforward. Der Hook kann die Ausgabe modifizieren. Er kann die Eingabe inplace modifizieren, hat aber keine Auswirkung auf den Forward-Pass, da er nach dem Aufruf vonforward()aufgerufen wird. Der Hook sollte die folgende Signatur haben:hook(module, args, output) -> None or modified output
Wenn
with_kwargsTrueist, werden dem Forward-Hook die an die Forward-Funktion übergebenenkwargsübergeben und es wird erwartet, dass er die möglicherweise modifizierte Ausgabe zurückgibt. Der Hook sollte die folgende Signatur haben:hook(module, args, kwargs, output) -> None or modified output
- Parameter:
hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.
prepend (bool) – Wenn
True, wird der bereitgestelltehookvor allen vorhandenenforward-Hooks auf diesemtorch.nn.Moduleaufgerufen. Andernfalls wird der bereitgestelltehooknach allen vorhandenenforward-Hooks auf diesemtorch.nn.Moduleaufgerufen. Beachten Sie, dass globaleforward-Hooks, die mitregister_module_forward_hook()registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden. Standard:Falsewith_kwargs (bool) – Wenn
True, wird demhookdie an die Forward-Funktion übergebenen kwargs übergeben. Standard:Falsealways_call (bool) – Wenn
True, wird derhookunabhängig davon ausgeführt, ob beim Aufruf des Moduls eine Ausnahme ausgelöst wird. Standard:False
- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_forward_pre_hook(hook: Union[Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle#
Registriert einen Forward Pre-Hook am Modul.
Der Hook wird jedes Mal aufgerufen, bevor
forward()aufgerufen wird.Wenn
with_kwargsfalsch ist oder nicht angegeben wird, enthält die Eingabe nur die positionsbezogenen Argumente, die dem Modul übergeben wurden. Schlüsselwortargumente werden nicht an die Hooks weitergegeben, sondern nur anforward. Der Hook kann die Eingabe modifizieren. Der Benutzer kann entweder ein Tupel oder einen einzelnen modifizierten Wert im Hook zurückgeben. Wir wickeln den Wert in ein Tupel ein, wenn ein einzelner Wert zurückgegeben wird (es sei denn, dieser Wert ist bereits ein Tupel). Der Hook sollte die folgende Signatur haben:hook(module, args) -> None or modified input
Wenn
with_kwargswahr ist, werden dem Forward Pre-Hook die an die Forward-Funktion übergebenen kwargs übergeben. Und wenn der Hook die Eingabe modifiziert, sollten sowohl die args als auch die kwargs zurückgegeben werden. Der Hook sollte die folgende Signatur haben:hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
- Parameter:
hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.
prepend (bool) – Wenn wahr, wird der bereitgestellte
hookvor allen vorhandenenforward_pre-Hooks auf diesemtorch.nn.Moduleaufgerufen. Andernfalls wird der bereitgestelltehooknach allen vorhandenenforward_pre-Hooks auf diesemtorch.nn.Moduleaufgerufen. Beachten Sie, dass globaleforward_pre-Hooks, die mitregister_module_forward_pre_hook()registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden. Standard:Falsewith_kwargs (bool) – Wenn wahr, wird dem
hookdie an die Forward-Funktion übergebenen kwargs übergeben. Standard:False
- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_full_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle#
Registriert einen Backward-Hook am Modul.
Der Hook wird jedes Mal aufgerufen, wenn die Gradienten in Bezug auf ein Modul berechnet werden, d.h. der Hook wird ausgeführt, wenn und nur wenn die Gradienten in Bezug auf die Modulausgaben berechnet werden. Der Hook sollte die folgende Signatur haben:
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
Die
grad_inputundgrad_outputsind Tupel, die die Gradienten in Bezug auf die Eingaben bzw. Ausgaben enthalten. Der Hook sollte seine Argumente nicht modifizieren, kann aber optional einen neuen Gradienten in Bezug auf die Eingabe zurückgeben, der in nachfolgenden Berechnungen anstelle vongrad_inputverwendet wird.grad_inputentspricht nur den als positionsbezogene Argumente übergebenen Eingaben, und alle Kwargs-Argumente werden ignoriert. Einträge ingrad_inputundgrad_outputsindNonefür alle Nicht-Tensor-Argumente.Aus technischen Gründen erhält die Forward-Funktion dieses Hooks, wenn er auf ein Modul angewendet wird, eine Ansicht jedes Tensors, der an das Modul übergeben wird. Ebenso erhält der Aufrufer eine Ansicht jedes Tensors, der von der Forward-Funktion des Moduls zurückgegeben wird.
Warnung
Die Modifizierung von Eingaben oder Ausgaben inplace ist bei der Verwendung von Backward-Hooks nicht erlaubt und führt zu einem Fehler.
- Parameter:
hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.
prepend (bool) – Wenn wahr, wird der bereitgestellte
hookvor allen vorhandenenbackward-Hooks auf diesemtorch.nn.Moduleaufgerufen. Andernfalls wird der bereitgestelltehooknach allen vorhandenenbackward-Hooks auf diesemtorch.nn.Moduleaufgerufen. Beachten Sie, dass globalebackward-Hooks, die mitregister_module_full_backward_hook()registriert wurden, vor allen Hooks aufgerufen werden, die mit dieser Methode registriert wurden.
- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_full_backward_pre_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle#
Registriert einen Backward-Pre-Hook für das Modul.
Der Hook wird jedes Mal aufgerufen, wenn die Gradienten für das Modul berechnet werden. Der Hook sollte die folgende Signatur haben
hook(module, grad_output) -> tuple[Tensor] or None
Der
grad_outputist ein Tupel. Der Hook sollte seine Argumente nicht ändern, kann aber optional ein neues Gradientenobjekt bezüglich der Ausgabe zurückgeben, das anstelle vongrad_outputin nachfolgenden Berechnungen verwendet wird. Einträge ingrad_outputsindNonefür alle Nicht-Tensor-Argumente.Aus technischen Gründen erhält die Forward-Funktion dieses Hooks, wenn er auf ein Modul angewendet wird, eine Ansicht jedes Tensors, der an das Modul übergeben wird. Ebenso erhält der Aufrufer eine Ansicht jedes Tensors, der von der Forward-Funktion des Moduls zurückgegeben wird.
Warnung
Das Ändern von Eingaben vor Ort ist bei der Verwendung von Backward-Hooks nicht erlaubt und führt zu einem Fehler.
- Parameter:
hook (Callable) – Der vom Benutzer definierte zu registrierende Hook.
prepend (bool) – Wenn True, wird der bereitgestellte
hookvor allen vorhandenenbackward_preHooks für diesestorch.nn.Moduleausgelöst. Andernfalls wird der bereitgestelltehooknach allen vorhandenenbackward_preHooks für diesestorch.nn.Moduleausgelöst. Beachten Sie, dass globalebackward_preHooks, die mitregister_module_full_backward_pre_hook()registriert wurden, vor allen Hooks ausgelöst werden, die mit dieser Methode registriert wurden.
- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_load_state_dict_post_hook(hook)#
Registriert einen Post-Hook, der nach dem Aufruf von
load_state_dict()des Moduls ausgeführt wird.- Er sollte die folgende Signatur haben:
hook(module, incompatible_keys) -> None
Das Argument
moduleist das aktuelle Modul, auf dem dieser Hook registriert ist, und das Argumentincompatible_keysist einNamedTuple, das die Attributemissing_keysundunexpected_keysenthält.missing_keysist einelistvonstrmit den fehlenden Schlüsseln undunexpected_keysist einelistvonstrmit den unerwarteten Schlüsseln.Die gegebenen incompatible_keys können bei Bedarf vor Ort geändert werden.
Beachten Sie, dass die Prüfungen, die beim Aufruf von
load_state_dict()mitstrict=Truedurchgeführt werden, von den Änderungen beeinflusst werden, die der Hook anmissing_keysoderunexpected_keysvornimmt, wie erwartet. Hinzufügungen zu beiden Schlüsselmengen führen zu einem Fehler, wennstrict=Trueist, und das Leeren beider fehlenden und unerwarteten Schlüssel vermeidet einen Fehler.- Rückgabe:
Ein Handle, das verwendet werden kann, um den hinzugefügten Hook durch Aufrufen von
handle.remove()zu entfernen.- Rückgabetyp:
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook)#
Registriert einen Pre-Hook, der vor dem Aufruf von
load_state_dict()des Moduls ausgeführt wird.- Er sollte die folgende Signatur haben:
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
- Parameter:
hook (Callable) – Aufrufbare Funktion, die vor dem Laden des state_dict aufgerufen wird.
- register_state_dict_post_hook(hook)#
Registriert einen Post-Hook für die Methode
state_dict().- Er sollte die folgende Signatur haben:
hook(module, state_dict, prefix, local_metadata) -> None
Die registrierten Hooks können das
state_dictvor Ort ändern.
- register_state_dict_pre_hook(hook)#
Registriert einen Pre-Hook für die Methode
state_dict().- Er sollte die folgende Signatur haben:
hook(module, prefix, keep_vars) -> None
Die registrierten Hooks können verwendet werden, um eine Vorverarbeitung durchzuführen, bevor der
state_dictAufruf erfolgt.
- requires_grad_(requires_grad: bool = True) T#
Ändert, ob Autograd Operationen auf Parametern in diesem Modul aufzeichnen soll.
Diese Methode setzt die
requires_gradAttribute der Parameter vor Ort.Diese Methode ist hilfreich zum Einfrieren eines Teils des Moduls für Finetuning oder zum individuellen Trainieren von Teilen eines Modells (z. B. GAN-Training).
Siehe Lokales Deaktivieren der Gradientenberechnung für einen Vergleich zwischen .requires_grad_() und mehreren ähnlichen Mechanismen, die mit ihm verwechselt werden könnten.
- set_extra_state(state: Any) None#
Setzt den zusätzlichen Zustand, der im geladenen state_dict enthalten ist.
Diese Funktion wird von
load_state_dict()aufgerufen, um zusätzlichen Zustand im state_dict zu verarbeiten. Implementieren Sie diese Funktion und eine entsprechendeget_extra_state()für Ihr Modul, wenn Sie zusätzlichen Zustand in seinem state_dict speichern müssen.- Parameter:
state (dict) – Zusätzlicher Zustand aus dem state_dict
- set_submodule(target: str, module: Module, strict: bool = False) None#
Setzt das durch
targetangegebene Untermodul, falls es existiert, andernfalls wird ein Fehler ausgelöst.Hinweis
Wenn
strictaufFalse(Standard) gesetzt ist, ersetzt die Methode ein vorhandenes Untermodul oder erstellt ein neues Untermodul, wenn das übergeordnete Modul existiert. WennstrictaufTruegesetzt ist, versucht die Methode nur, ein vorhandenes Untermodul zu ersetzen und löst einen Fehler aus, wenn das Untermodul nicht existiert.Beispielsweise nehmen wir an, Sie haben ein
nn.ModuleA, das wie folgt aussieht:A( (net_b): Module( (net_c): Module( (conv): Conv2d(3, 3, 3) ) (linear): Linear(3, 3) ) )(Das Diagramm zeigt ein
nn.ModuleA.Ahat ein verschachteltes Untermodulnet_b, das selbst zwei Untermodulenet_cundlinearhat.net_chat dann ein Untermodulconv.)Um
Conv2ddurch ein neues UntermodulLinearzu überschreiben, könnten Sieset_submodule("net_b.net_c.conv", nn.Linear(1, 1))aufrufen, wobeistrictTrueoderFalsesein kann.Um ein neues Untermodul
Conv2dzum vorhandenennet_bModul hinzuzufügen, rufen Sieset_submodule("net_b.conv", nn.Conv2d(1, 1, 1))auf.Im obigen Beispiel, wenn Sie
strict=Truesetzen undset_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)aufrufen, wird ein AttributeError ausgelöst, danet_bkein Untermodul namensconvhat.- Parameter:
target – Der vollständig qualifizierte String-Name des zu suchenden Untermoduls. (Siehe obiges Beispiel für die Angabe eines vollständig qualifizierten Strings.)
module – Das Modul, auf das das Untermodul gesetzt werden soll.
strict – Wenn
False, ersetzt die Methode ein vorhandenes Untermodul oder erstellt ein neues Untermodul, wenn das übergeordnete Modul existiert. WennTrue, versucht die Methode nur, ein vorhandenes Untermodul zu ersetzen und löst einen Fehler aus, wenn das Untermodul noch nicht existiert.
- Ausnahmen:
ValueError – Wenn der
targetString leer ist oder wennmodulekeine Instanz vonnn.Moduleist.AttributeError – Wenn zu irgendeinem Zeitpunkt entlang des Pfades, der sich aus dem
targetString ergibt, der (Unter-)Pfad auf einen nicht existierenden Attributnamen oder ein Objekt, das keine Instanz vonnn.Moduleist, aufgelöst wird.
Siehe
torch.Tensor.share_memory_().
- to(*args, **kwargs)#
Verschieben und/oder umwandeln der Parameter und Puffer.
Dies kann aufgerufen werden als
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
- to(memory_format=torch.channels_last)
Die Signatur ähnelt
torch.Tensor.to(), akzeptiert aber nur Floating-Point- oder komplexedtypes. Zusätzlich werden mit dieser Methode nur Floating-Point- oder komplexe Parameter und Puffer indtype(falls angegeben) umgewandelt. Die ganzzahligen Parameter und Puffer werden nachdeviceverschoben, falls dieses angegeben ist, jedoch mit unveränderten dtypes. Wennnon_blockinggesetzt ist, versucht es, asynchron zum Host zu konvertieren/verschieben (falls möglich), z. B. CPU-Tensoren mit angepinnter Speicherzuweisung auf CUDA-Geräte zu verschieben.Siehe unten für Beispiele.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- Parameter:
device (
torch.device) – Das gewünschte Gerät der Parameter und Puffer in diesem Modul.dtype (
torch.dtype) – Der gewünschte Floating-Point- oder komplexe dtype der Parameter und Puffer in diesem Modul.tensor (torch.Tensor) – Tensor, dessen dtype und Gerät die gewünschten dtype und Gerät für alle Parameter und Puffer in diesem Modul sind.
memory_format (
torch.memory_format) – Das gewünschte Speicherformat für 4D-Parameter und Puffer in diesem Modul (nur Keyword-Argument).
- Rückgabe:
self
- Rückgabetyp:
Beispiele
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) T#
Verschiebt die Parameter und Puffer auf das angegebene Gerät, ohne den Speicher zu kopieren.
- Parameter:
device (
torch.device) – Das gewünschte Gerät der Parameter und Puffer in diesem Modul.recurse (bool) – Ob Parameter und Puffer von Untermodulen rekursiv auf das angegebene Gerät verschoben werden sollen.
- Rückgabe:
self
- Rückgabetyp:
- type(dst_type: Union[dtype, str]) T#
Wandelt alle Parameter und Puffer in
dst_typeum.Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- xpu(device: Optional[Union[int, device]] = None) T#
Verschiebt alle Modellparameter und Puffer zu XPU.
Dies macht auch zugehörige Parameter und Puffer zu unterschiedlichen Objekten. Daher sollte es vor der Konstruktion des Optimierers aufgerufen werden, wenn das Modul optimiert werden soll.
Hinweis
Diese Methode modifiziert das Modul direkt (in-place).
- zero_grad(set_to_none: bool = True) None#
Setzt die Gradienten aller Modellparameter zurück.
Siehe ähnliche Funktion unter
torch.optim.Optimizerfür mehr Kontext.- Parameter:
set_to_none (bool) – Anstatt auf Null zu setzen, werden die Grads auf None gesetzt. Siehe
torch.optim.Optimizer.zero_grad()für Details.