Unterstützung für den Array-API-Standard#

Hinweis

Die Unterstützung des Array-API-Standards ist noch experimentell und hinter einer Umgebungsvariablen verborgen. Derzeit ist nur ein kleiner Teil der öffentlichen API abgedeckt.

Diese Anleitung beschreibt, wie Sie den Python Array-API-Standard **verwenden** und **Unterstützung dafür hinzufügen**. Dieser Standard ermöglicht es Benutzern, jede Array-API-kompatible Array-Bibliothek sofort mit Teilen von SciPy zu verwenden.

Das RFC definiert, wie SciPy die Unterstützung für den Standard implementiert, wobei das Hauptprinzip „Array-Typ rein ergibt Array-Typ raus“ ist. Darüber hinaus führt die Implementierung eine strengere Validierung von erlaubten Array-ähnlichen Eingaben durch, z. B. werden NumPy-Matrix- und Masked-Array-Instanzen sowie Arrays mit Objekt-dtype abgelehnt.

Im Folgenden wird ein Array-API-kompatibler Namespace als xp bezeichnet.

Verwendung der Unterstützung für den Array-API-Standard#

Um die Unterstützung für den Array-API-Standard zu aktivieren, muss eine Umgebungsvariable gesetzt werden, bevor SciPy importiert wird

export SCIPY_ARRAY_API=1

Dies aktiviert sowohl die Unterstützung für den Array-API-Standard als auch die strengere Eingabevalidierung für Array-ähnliche Argumente. _Beachten Sie, dass diese Umgebungsvariable als temporäre Maßnahme gedacht ist, um inkrementelle Änderungen vorzunehmen und sie in „main“ zu integrieren, ohne die Abwärtskompatibilität sofort zu beeinträchtigen. Wir beabsichtigen nicht, diese Umgebungsvariable langfristig beizubehalten._

Dieses Clustering-Beispiel zeigt die Verwendung mit PyTorch-Tensoren als Eingaben und Rückgabewerte

>>> import torch
>>> from scipy.cluster.vq import vq
>>> code_book = torch.tensor([[1., 1., 1.],
...                           [2., 2., 2.]])
>>> features  = torch.tensor([[1.9, 2.3, 1.7],
...                           [1.5, 2.5, 2.2],
...                           [0.8, 0.6, 1.7]])
>>> code, dist = vq(features, code_book)
>>> code
tensor([1, 1, 0], dtype=torch.int32)
>>> dist
tensor([0.4359, 0.7348, 0.8307])

Beachten Sie, dass das obige Beispiel für PyTorch CPU-Tensoren funktioniert. Für GPU-Tensoren oder CuPy-Arrays ist das erwartete Ergebnis für vq ein TypeError, da vq kompilierten Code in seiner Implementierung verwendet, der auf der GPU nicht funktioniert.

Die strengere Array-Eingabevalidierung lehnt np.matrix und np.ma.MaskedArray Instanzen sowie Arrays mit object dtype ab

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
...                       [2., 2., 2.]])
>>> features  = np.array([[1.9, 2.3, 1.7],
...                       [1.5, 2.5, 2.2],
...                       [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))

>>> # The above uses numpy arrays; trying to use np.matrix instances or object
>>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
>>> vq(np.asmatrix(features), code_book)
...
TypeError: 'numpy.matrix' are not supported

>>> vq(np.ma.asarray(features), code_book)
...
TypeError: 'numpy.ma.MaskedArray' are not supported

>>> vq(features.astype(np.object_), code_book)
...
TypeError: object arrays are not supported

Derzeit unterstützte Funktionalität#

Die folgenden Module bieten Unterstützung für den Array-API-Standard, wenn die Umgebungsvariable gesetzt ist

Einzelne Funktionen in den obigen Modulen bieten eine Fähigkeitstabelle in der Dokumentation wie die untenstehende. Wenn die Tabelle fehlt, unterstützt die Funktion derzeit keine anderen Backends als NumPy.

Beispiel für eine Fähigkeitstabelle#

Bibliothek

CPU

GPU

NumPy

n/a

CuPy

n/a

PyTorch

JAX

⚠️ kein JIT

Dask

n/a

Im obigen Beispiel hat die Funktion eine gewisse Unterstützung für NumPy, CuPy, PyTorch und JAX-Arrays, aber keine Unterstützung für Dask-Arrays. Einige Backends wie JAX und PyTorch unterstützen nativ mehrere Geräte (CPU und GPU), aber die SciPy-Unterstützung für solche Arrays kann begrenzt sein; beispielsweise wird erwartet, dass diese SciPy-Funktion nur mit JAX-Arrays funktioniert, die sich auf der CPU befinden. Zusätzlich können einige Backends größere Einschränkungen haben; im Beispiel schlägt die Funktion bei der Ausführung innerhalb von jax.jit fehl. Zusätzliche Einschränkungen können im Docstring der Funktion aufgeführt sein.

Während die mit „n/a“ gekennzeichneten Elemente der Tabelle naturgemäß außerhalb des Rahmens liegen, arbeiten wir kontinuierlich daran, den Rest auszufüllen. Das Dask-Wrapping um Backends, die nicht NumPy sind (insbesondere CuPy), liegt derzeit außerhalb des Rahmens, könnte sich aber in Zukunft ändern.

Weitere Informationen finden Sie im Tracker-Issue.

Implementierungshinweise#

Ein wesentlicher Teil der Unterstützung für den Array-API-Standard und spezifische Kompatibilitätsfunktionen für Numpy, CuPy und PyTorch wird über array-api-compat bereitgestellt. Dieses Paket ist über ein Git-Submodul (unter scipy/_lib) im SciPy-Codebase enthalten, sodass keine neuen Abhängigkeiten eingeführt werden.

array-api-compat bietet generische Hilfsfunktionen und fügt Aliase wie xp.concat hinzu (das für NumPy vor der Hinzufügung von np.concat durch NumPy in NumPy 2.0 auf np.concatenate abgebildet wurde). Dies ermöglicht die Verwendung einer einheitlichen API über NumPy, PyTorch, CuPy und JAX (andere Bibliotheken wie Dask sind in Arbeit).

Wenn die Umgebungsvariable nicht gesetzt ist und die Unterstützung für den Array-API-Standard in SciPy daher deaktiviert ist, verwenden wir weiterhin die umschlossene Version des NumPy-Namespaces, nämlich array_api_compat.numpy. Dies sollte das Verhalten von SciPy-Funktionen nicht ändern, da es sich effektiv um den bestehenden numpy-Namespace mit einer Reihe von Aliases und einigen geänderten/hinzugefügten Funktionen für die Unterstützung des Array-API-Standards handelt. Wenn die Unterstützung aktiviert ist, ist xp = array_namespace(input) der Standard-kompatible Namespace, der den Eingabe-Array-Typ einer Funktion zuordnet (z. B. wenn die Eingabe zu cluster.vq.kmeans ein PyTorch-Tensor ist, dann ist xp array_api_compat.torch).

Hinzufügen der Unterstützung für den Array-API-Standard zu einer SciPy-Funktion#

Neue, zu SciPy hinzugefügte Codes sollten, so weit wie möglich, so eng wie möglich dem Array-API-Standard folgen (diese Funktionen sind in der Regel auch Best-Practice-Idiome für die NumPy-Nutzung). Durch das Befolgen des Standards ist das Hinzufügen von Unterstützung für den Array-API-Standard in der Regel unkompliziert, und wir müssen idealerweise keine Anpassungen pflegen.

Verschiedene Hilfsfunktionen sind in scipy._lib._array_api verfügbar — siehe __all__ in diesem Modul für eine Liste der aktuellen Helfer und deren Docstrings für weitere Informationen.

Um Unterstützung zu einer SciPy-Funktion hinzuzufügen, die in einer .py-Datei definiert ist, müssen Sie Folgendes ändern:

  1. Eingabe-Array-Validierung,

  2. Verwendung von xp anstelle von np-Funktionen,

  3. Bei Aufrufen von kompiliertem Code konvertieren Sie das Array vor dem Aufruf in ein NumPy-Array und nach dem Aufruf zurück in den Eingabe-Array-Typ.

Die Eingabe-Array-Validierung verwendet das folgende Muster

xp = array_namespace(arr) # where arr is the input array
# alternatively, if there are multiple array inputs, include them all:
xp = array_namespace(arr1, arr2)

# replace np.asarray with xp.asarray
arr = xp.asarray(arr)
# uses of non-standard parameters of np.asarray can be replaced with _asarray
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

Beachten Sie, dass, wenn eine Eingabe ein Nicht-NumPy-Array-Typ ist, alle Array-ähnlichen Eingaben von diesem Typ sein müssen; der Versuch, Nicht-NumPy-Arrays mit Listen, Python-Skalaren oder anderen beliebigen Python-Objekten zu mischen, führt zu einer Ausnahme. Für NumPy-Arrays werden diese Typen aus Gründen der Abwärtskompatibilität weiterhin akzeptiert.

Wenn eine Funktion nur einmal in kompilierten Code aufruft, verwenden Sie das folgende Muster

x = np.asarray(x)  # convert to numpy right before compiled call(s)
y = _call_compiled_code(x)
y = xp.asarray(y)  # convert back to original array type

Wenn es mehrere Aufrufe an kompilierten Code gibt, stellen Sie sicher, dass die Konvertierung nur einmal erfolgt, um zu viel Overhead zu vermeiden.

Hier ist ein Beispiel für eine hypothetische öffentliche SciPy-Funktion toto

def toto(a, b):
    a = np.asarray(a)
    b = np.asarray(b, copy=True)

    c = np.sum(a) - np.prod(b)

    # this is some C or Cython call
    d = cdist(c)

    return d

Sie würden dies wie folgt konvertieren

def toto(a, b):
    xp = array_namespace(a, b)
    a = xp.asarray(a)
    b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

    c = xp.sum(a) - xp.prod(b)

    # this is some C or Cython call
    c = np.asarray(c)
    d = cdist(c)
    d = xp.asarray(d)

    return d

Der Durchlauf durch kompilierten Code erfordert die Rückkehr zu einem NumPy-Array, da die Erweiterungsmodule von SciPy nur mit NumPy-Arrays (oder Speicheransichten im Fall von Cython) arbeiten. Für Arrays auf der CPU sollten die Konvertierungen verlustfrei sein, während auf GPU und anderen Geräten der Versuch einer Konvertierung eine Ausnahme auslöst. Der Grund dafür ist, dass eine stille Datenübertragung zwischen Geräten als schlechte Praxis angesehen wird, da sie wahrscheinlich ein großes und schwer zu erkennendes Leistungsengpass ist.

Hinzufügen von Tests#

Um einen Test auf mehreren Array-Backends auszuführen, sollten Sie die xp-Fixture zu ihm hinzufügen, die dem aktuell getesteten Array-Namespace zugewiesen ist.

Die folgenden Pytest-Marker sind verfügbar

  • skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): überspringt bestimmte Backends oder Kategorien von Backends. Siehe Docstring von scipy.conftest.skip_or_xfail_xp_backends für Informationen zur Verwendung dieses Markers zum Überspringen von Tests.

  • xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): markiert bestimmte Backends oder Kategorien von Backends als fehlschlagend. Siehe Docstring von scipy.conftest.skip_or_xfail_xp_backends für Informationen zur Verwendung dieses Markers, um Tests als fehlschlagend zu markieren.

  • skip_xp_invalid_arg wird verwendet, um Tests zu überspringen, die Argumente verwenden, die ungültig sind, wenn SCIPY_ARRAY_API aktiviert ist. Zum Beispiel übergeben einige Tests von scipy.stats-Funktionen masked arrays an die getestete Funktion, aber masked arrays sind inkompatibel mit der Array-API. Die Verwendung des skip_xp_invalid_arg Dekorators ermöglicht es diesen Tests, sich vor Regressionen zu schützen, wenn SCIPY_ARRAY_API nicht verwendet wird, ohne dass es zu Fehlern kommt, wenn SCIPY_ARRAY_API verwendet wird. Mit der Zeit werden wir möchten, dass diese Funktionen Deprecation Warnings ausgeben, wenn sie ungültige Array-API-Eingaben erhalten, und dieser Dekorator prüft, ob die Deprecation Warning ausgegeben wird, ohne dass der Test fehlschlägt. Wenn SCIPY_ARRAY_API=1 zum Standard- und einzigen Verhalten wird, werden diese Tests (und der Dekorator selbst) entfernt.

  • array_api_backends: Dieser Marker wird automatisch von der xp-Fixture zu allen Tests hinzugefügt, die sie verwenden. Dies ist z. B. nützlich, um alle und nur diese Tests auszuwählen.

    python dev.py test -b all -m array_api_backends
    

scipy._lib._array_api enthält Array-agnostische Assertionen wie xp_assert_close, die zum Ersetzen von Assertionen aus numpy.testing verwendet werden können.

Wenn diese Assertionen innerhalb eines Tests ausgeführt werden, der die xp-Fixture verwendet, erzwingen sie, dass die Namespaces der tatsächlichen und der gewünschten Arrays mit dem Namespace übereinstimmen, der von der Fixture gesetzt wurde. Tests ohne die xp-Fixture leiten den Namespace vom gewünschten Array ab. Diese Maschinerie kann durch explizites Übergeben des xp=-Parameters an die Assertionsfunktionen überschrieben werden.

Die folgenden Beispiele zeigen, wie die Marker verwendet werden

from scipy.conftest import skip_xp_invalid_arg
from scipy._lib._array_api import xp_assert_close
...
@pytest.mark.skip_xp_backends(np_only=True, reason='skip reason')
def test_toto1(self, xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    xp_assert_close(toto(a, b), a)
...
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto2(self, xp):
    ...
...
# Do not run when SCIPY_ARRAY_API is used
@skip_xp_invalid_arg
def test_toto_masked_array(self):
    ...

Das Übergeben von Backend-Namen an exceptions bedeutet, dass sie nicht von cpu_only=True oder eager_only=True übersprungen werden. Dies ist nützlich, wenn die Delegation für einige, aber nicht alle Nicht-CPU-Backends implementiert ist und der CPU-Code-Pfad eine Konvertierung nach NumPy für kompilierten Code erfordert.

# array-api-strict and CuPy will always be skipped, for the given reasons.
# All libraries using a non-CPU device will also be skipped, apart from
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
@pytest.mark.skip_xp_backends(cpu_only=True, exceptions=['jax.numpy'])
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto(self, xp):
    ...

Nach Anwendung dieser Marker kann dev.py test mit der neuen Option -b oder --array-api-backend verwendet werden.

python dev.py test -b numpy -b torch -s cluster

Dies setzt SCIPY_ARRAY_API automatisch entsprechend. Um eine Bibliothek zu testen, die mehrere Geräte mit einem Nicht-Standard-Gerät hat, kann eine zweite Umgebungsvariable (SCIPY_DEVICE, nur in der Testsuite verwendet) gesetzt werden. Gültige Werte hängen von der getesteten Array-Bibliothek ab, z. B. für PyTorch sind gültige Werte "cpu", "cuda", "mps". Um die Testsuite mit dem PyTorch MPS-Backend auszuführen, verwenden Sie: SCIPY_DEVICE=mps python dev.py test -b torch.

Beachten Sie, dass es einen GitHub Actions Workflow gibt, der mit array-api-strict, PyTorch und JAX auf CPU testet.

Testen des JAX JIT-Compilers#

Der JAX JIT-Compiler führt spezielle Einschränkungen für allen von @jax.jit umschlossenen Code ein, die beim Ausführen von JAX im Eager-Modus nicht vorhanden sind. Insbesondere werden boolesche Masken in __getitem__ und at nicht unterstützt, und Sie können die Arrays nicht materialisieren, indem Sie bool(), float(), np.asarray() usw. darauf anwenden.

Um SciPy korrekt mit JAX zu testen, müssen Sie die getesteten SciPy-Funktionen mit @jax.jit taggen, bevor sie von den Unit-Tests aufgerufen werden. Um dies zu erreichen, sollten Sie sie in Ihrem Testmodul wie folgt taggen:

from scipy._lib._lazy_testing import lazy_xp_function
from scipy.mymodule import toto

lazy_xp_function(toto)

def test_toto(xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    # When xp==jax.numpy, toto is wrapped with @jax.jit
    xp_assert_close(toto(a, b), a)

Vollständige Dokumentation finden Sie in scipy/_lib/_lazy_testing.py.

Zusätzliche Informationen#

Hier sind einige zusätzliche Ressourcen, die einige Designentscheidungen motiviert und während der Entwicklungsphase geholfen haben

  • Erster PR mit einigen Diskussionen

  • Schnellstart von diesem PR und einige Inspirationen von scikit-learn.

  • PR zur Hinzufügung von Array-API-Unterstützung zu scikit-learn

  • Einige andere relevante scikit-learn PRs: #22554 und #25956