scipy
ENH: array types, signal: delegate to CuPy and JAX for correlations and convolutions
#20772
Merged

ENH: array types, signal: delegate to CuPy and JAX for correlations and convolutions #20772

ev-br merged 10 commits into scipy:main from ev-br:sigtools_convolve_cupy
ev-br
ev-br362 days ago (edited 287 days ago)

Towards gh-20678

What does this implement/fix?

Delegate signal.convolve and its ilk to cupyx.scipy.signal.convolve if inputs are cupy arrays. For other array types, do the usual convert to numpy, run, convert back dance.

Additional information

CuPy provides a near-complete clone of the scipy API in the cupyx.scipy namespace. We can treat these CuPy functions as accelerators: if a scipy function detects that its arguments are cupy-compatible, it can delegate all work to the cupyx function.

ev-br ev-br requested a review from larsoner larsoner 362 days ago
ev-br ev-br requested a review from ilayn ilayn 362 days ago
github-actions github-actions added scipy.signal
github-actions github-actions added RFC
lucascolley
lucascolley commented on 2024-05-23
Conversation is marked as resolved
Show resolved
scipy/signal/_signaltools.py
1315 else:
1316 can_dispatch = True
1317
1318
# XXX: inputs are cupy arrays and cupyx.scipy.convolve cannot be used
1319
# do we want to convert to numpy/use the CPU version/convert back instead?
lucascolley362 days ago (edited 362 days ago)
  • If the only fallback implementation is np-only, then I think the array types RFC says to just error out instead of forcing the device transfer (which happens naturally if you try to put a CuPy array in np.asarray).
  • If we have an xp-agnostic implementation (albeit a slower one than can be achieved with an xp-specific implementation) which avoids device transfers, then it makes sense to fall back to that.
ev-br362 days ago👍 1

There's no xp-agnostic implementation, for now at least. Let's revisit if/when one arrives then.

Conversation is marked as resolved
Show resolved
scipy/signal/_signaltools.py
lucascolley362 days ago

If you want to try extending this support to other libraries which have a scipy namespace, consider using scipy_namespace_for(xp). That only has JAX in addition right now (since there is no torch.scipy), but it can still help reduce the LOC for dispatching a little, e.g. get_array_special_func.

ev-br362 days ago

scipy_namespace_for is broken on cupyx unless the specific submodule is already imported:

In [1]: from scipy._lib._array_api import scipy_namespace_for, array_namespace

In [2]: import cupy

In [3]: xp = array_namespace(cupy.ones(3))

In [4]: xps = scipy_namespace_for(xp)

In [5]: xps.signal
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 1
----> 1 xps.signal

AttributeError: module 'cupyx.scipy' has no attribute 'signal'

In [6]: getattr(xps, 'signal')
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 getattr(xps, 'signal')

AttributeError: module 'cupyx.scipy' has no attribute 'signal'

In [7]: import cupyx

In [8]: getattr(cupyx.scipy, 'signal')
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[8], line 1
----> 1 getattr(cupyx.scipy, 'signal')

AttributeError: module 'cupyx.scipy' has no attribute 'signal'

In [9]: import cupyx.scipy.signal
/home/ev-br/.conda/envs/scipy-dev/lib/python3.11/site-packages/cupyx/jit/_interface.py:173: FutureWarning: cupyx.jit.rawkernel is experimental. The interface can change in the future.
  cupy._util.experimental('cupyx.jit.rawkernel')

In [10]: getattr(xps, 'signal')
Out[10]: <module 'cupyx.scipy.signal' from '/home/ev-br/.conda/envs/scipy-dev/lib/python3.11/site-packages/cupyx/scipy/signal/__init__.py'>
lucascolley362 days ago

I'm currently GPU-less, but have you tried getattr(xps.signal, func.__name__)? That's pretty much what we do in special right now - is that broken? Running the special alternative backends tests on CuPy should show if it is broken, no?

ev-br362 days ago
In [1]: import cupyx

In [2]: import cupyx.scipy

In [3]: getattr(cupyx.scipy, 'signal.convolve')
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 1
----> 1 getattr(cupyx.scipy, 'signal.convolve')

AttributeError: module 'cupyx.scipy' has no attribute 'signal.convolve'
lucascolley362 days ago (edited 362 days ago)

The example in special is:

import cupyx
getattr(cupyx.scipy.special, 'gammainc', None)

What happens if you run that? Is it different if you run python dev.py test -b copy -t scipy.special.tests.test_support_alternative_backends?

ev-br361 days ago👀 1

special seems special in cupy:

In [1]: import cupyx
   ...: getattr(cupyx.scipy.special, 'gammainc', None)
Out[1]: <ufunc 'cupyx_scipy_gammainc'>

In [2]: import cupyx
   ...: getattr(cupyx.scipy.signal, 'convolve', None)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 2
      1 import cupyx
----> 2 getattr(cupyx.scipy.signal, 'convolve', None)

AttributeError: module 'cupyx.scipy' has no attribute 'signal'

In [3]: import cupyx
   ...: getattr(cupyx.scipy.linalg, 'solve', None)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 2
      1 import cupyx
----> 2 getattr(cupyx.scipy.linalg, 'solve', None)

AttributeError: module 'cupyx.scipy' has no attribute 'linalg'

In [4]: import cupyx
   ...: getattr(cupyx.scipy.ndimage, 'correlate', None)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 2
      1 import cupyx
----> 2 getattr(cupyx.scipy.ndimage, 'correlate', None)

AttributeError: module 'cupyx.scipy' has no attribute 'ndimage'

and indeed $ python dev.py test -b cupy -t scipy.special.tests.test_support_alternative_backends is green for me locally.

rgommers361 days ago

So can we just add the minimal tweak to make special less special? I'd really like the mechanism of calling out to the scipy.submodule namespace in another library to not be so different between various submodules. It's be better to have a file _support_alternative_backends.py also in signal, and have the same kind of exposing of it in __init__.py rather than with a decorator.

lucascolley361 days ago

Indeed, it sounds like @kmaehashi is open to making cupyx.scipy behave like scipy w.r.t imports in cupy/cupy#8336

lucascolley lucascolley changed the title RFC/POC: dispatch from `scipy.submodule.func` to `cupyx.scipy.submodule.func` RFC/POC: array types: dispatch to `cupyx.scipy.submodule.func` 362 days ago
lucascolley lucascolley added array types
lucascolley lucascolley changed the title RFC/POC: array types: dispatch to `cupyx.scipy.submodule.func` RFC/POC: array types, signal: dispatch to `cupyx.scipy.submodule.func` 362 days ago
ev-br
lucascolley
lucascolley
lucascolley commented on 2024-05-23
Conversation is marked as resolved
Show resolved
scipy/signal/_signaltools.py
1345 cupyx_func = getattr(cupyx_module, func.__name__)
1346 return cupyx_func(*args, **kwds)
1347 elif can_dispatch and is_jax(xp):
1348
xps = scipy_namespace_for(xp)
lucascolley362 days ago👍 1

we should probably decide on what this variable should be called for consistency. I originally went for spx for the intended meaning of an x{p}-version of sp (scipy). I think that's a little clearer than xps (which is quite similar to eps), although I wouldn't be opposed to either xp_sp or xp_scipy.

lucascolley
lucascolley commented on 2024-05-23
Conversation is marked as resolved
Show resolved
scipy/signal/tests/test_signaltools.py
96 b = np.array([3920])
8297 c = convolve(a, b)
83 assert_equal(c, a * b)
84
85 def test_2d_arrays(self):
86 a = [[1, 2, 3], [3, 4, 5]]
87 b = [[2, 3, 4], [4, 5, 6]]
98 xp_assert_equal(c, a * b)
99
100
# @skip_xp_backends("jax.numpy") XXX: how to skip two backends
lucascolley362 days ago👍 1
Suggested change
# @skip_xp_backends("jax.numpy") XXX: how to skip two backends
@skip_xp_backends("cupy")
@skip_xp_backends("jax.numpy", "cupy")
lucascolley362 days ago (edited 362 days ago)

refer to the docstring at

scipy/scipy/conftest.py

Lines 173 to 199 in b421cd6

def skip_xp_backends(xp, request):
"""
Skip based on the ``skip_xp_backends`` marker.
Parameters
----------
*backends : tuple
Backends to skip, e.g. ``("array_api_strict", "torch")``.
These are overriden when ``np_only`` is ``True``, and are not
necessary to provide for non-CPU backends when ``cpu_only`` is ``True``.
reasons : list, optional
A list of reasons for each skip. When ``np_only`` is ``True``,
this should be a singleton list. Otherwise, this should be a list
of reasons, one for each corresponding backend in ``backends``.
If unprovided, default reasons are used. Note that it is not possible
to specify a custom reason with ``cpu_only``. Default: ``None``.
np_only : bool, optional
When ``True``, the test is skipped for all backends other
than the default NumPy backend. There is no need to provide
any ``backends`` in this case. To specify a reason, pass a
singleton list to ``reasons``. Default: ``False``.
cpu_only : bool, optional
When ``True``, the test is skipped on non-CPU devices.
There is no need to provide any ``backends`` in this case,
but any ``backends`` will also be skipped on the CPU.
Default: ``False``.
"""

The parameter is called reasons not reason, and accepts a list(str)

ev-br ev-br force pushed from e8d894b4 to 1443958c 362 days ago
ev-br
lucascolley
lucascolley
lucascolley commented on 2024-05-23
lucascolley
lucascolley commented on 2024-05-23
lucascolley lucascolley changed the title RFC/POC: array types, signal: dispatch to `cupyx.scipy.submodule.func` RFC/POC: array types, signal: dispatch to CuPy and JAX 361 days ago
lucascolley lucascolley removed review request from ilayn ilayn 361 days ago
lucascolley lucascolley removed review request from larsoner larsoner 361 days ago
lucascolley
lucascolley commented on 2024-05-23
lucascolley361 days ago

For the CI failures, I suspect you need to restrict pytest.mark.usefixtures("skip_xp_backends") to functions/classes you have converted, rather than putting it in pytestmark, for now.

lucascolley
lucascolley
lucascolley commented on 2024-05-23
ev-br ev-br requested a review from andyfaff andyfaff 361 days ago
lucascolley lucascolley removed review request from andyfaff andyfaff 361 days ago
ev-br ev-br force pushed from 2c5748af to b59d72db 361 days ago
ev-br ev-br force pushed from 403ed97a to b9e4f125 359 days ago
ev-br
lucascolley
lucascolley commented on 2024-05-25
lucascolley359 days ago👍 1

@mdhaber could you take a look at _support_alternative_backends.py as the author of the equivalent in special?

tylerjereddy
tylerjereddy commented on 2024-05-25
rgommers
rgommers commented on 2024-05-26
rgommers
rgommers
rgommers commented on 2024-05-26
ev-br ev-br force pushed from e1253441 to edc9e552 358 days ago
ev-br ev-br marked this pull request as draft 358 days ago
ev-br
ev-br commented on 2024-05-27
ev-br ev-br force pushed from 25cd9b65 to 01406df0 358 days ago
ev-br ev-br force pushed from fb8d4b25 to 53d0802d 358 days ago
ev-br
ev-br ev-br marked this pull request as ready for review 357 days ago
ev-br ev-br requested a review from person142 person142 357 days ago
ev-br ev-br requested a review from steppi steppi 357 days ago
ev-br
mdhaber
mdhaber
ev-br
mdhaber
ev-br
mdhaber
ev-br
ev-br
lucascolley
lucascolley commented on 2024-05-28
rgommers
mdhaber
ev-br
mdhaber
lucascolley
ev-br
ev-br
lucascolley
ev-br
lucascolley
lucascolley
lucascolley
ev-br
rgommers
lucascolley
ev-br ev-br force pushed from 57577fa8 to bab897fa 289 days ago
ev-br ev-br marked this pull request as draft 289 days ago
lucascolley
ev-br
ev-br ev-br force pushed from 0431cb32 to 1ace81be 288 days ago
ev-br ev-br force pushed from bc4aaac6 to a433f05e 288 days ago
ev-br
lucascolley
ev-br
ev-br ev-br changed the title RFC/POC: array types, signal: dispatch to CuPy and JAX array types, signal: dispatch to CuPy and JAX for correlations and convolutions 287 days ago
ev-br ev-br marked this pull request as ready for review 287 days ago
ev-br
lucascolley lucascolley changed the title array types, signal: dispatch to CuPy and JAX for correlations and convolutions ENH: array types, signal: dispatch to CuPy and JAX for correlations and convolutions 287 days ago
rgommers
rgommers commented on 2024-08-13
rgommers
rgommers
rgommers commented on 2024-08-13
rgommers280 days ago

I did a pass through this PR (looking more carefully at the code changes compared to the test changes). Overall this looks pretty good to me. I have a number of comments which would be good to address; all local things though - the approach looks good to me.

lucascolley
ev-br
lucascolley
lucascolley
rgommers
rgommers
lucascolley
ev-br
rgommers
rgommers
ev-br
ev-br
lucascolley
lucascolley
rgommers
lucascolley lucascolley changed the title ENH: array types, signal: dispatch to CuPy and JAX for correlations and convolutions ENH: array types, signal: delegate to CuPy and JAX for correlations and convolutions 279 days ago
ev-br ev-br force pushed from 733602a2 to 6263ceba 257 days ago
ev-br ev-br force pushed from 6263ceba to 8d0a1983 257 days ago
ev-br ev-br force pushed from 8d0a1983 to c004d1b9 257 days ago
ev-br ev-br removed RFC
ev-br
ev-br ev-br force pushed from d192207b to 757d473a 238 days ago
ev-br
lucascolley
lucascolley commented on 2024-09-25
lucascolley237 days ago

Some comments, I haven't looked at the test changes yet

ev-br ev-br force pushed from 757d473a to 95ff56b1 237 days ago
ev-br ev-br force pushed from 3200325a to a3466ef3 236 days ago
ev-br ev-br force pushed from a3466ef3 to 2b45c2b3 236 days ago
ev-br ev-br force pushed from 2b45c2b3 to 294202bf 236 days ago
ev-br
lucascolley
lucascolley commented on 2024-09-26
lucascolley236 days ago

no promises on when I'll find the time to go through the test changes, sorry!

tylerjereddy
tylerjereddy commented on 2024-09-27
tylerjereddy235 days ago

On this branch locally with SCIPY_DEVICE=cuda python dev.py test -j 8 -b all, I see an additional 1238 failures relative to the main branch.

Looks like it is all torch on GPU doing device transfers and erroring out.

The diff below the fold applies skips a bit aggressively, but deals with all the failures. We might want to be more selective than that though as it may disable a bunch of CuPy testing. That said, for many test classes, quite a few different contained tests fail, so it would also be annoying to mark each one.

diff --git a/scipy/signal/tests/test_signaltools.py b/scipy/signal/tests/test_signaltools.py
index 750cd0dab2..51530ad784 100644
--- a/scipy/signal/tests/test_signaltools.py
+++ b/scipy/signal/tests/test_signaltools.py
@@ -39,6 +39,7 @@ skip_xp_backends = pytest.mark.skip_xp_backends
 pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
 
 
+@skip_xp_backends(cpu_only=True)
 class TestConvolve:
 
     @skip_xp_backends("jax.numpy",
@@ -280,6 +281,7 @@ class TestConvolve:
 
 
 
+@skip_xp_backends(cpu_only=True)
 class TestConvolve2d:
 
     @skip_xp_backends("jax.numpy", reasons=["dtypes do not match"])
@@ -494,6 +496,7 @@ class TestConvolve2d:
 
 
 
+@skip_xp_backends(cpu_only=True)
 class TestFFTConvolve:
 
     @skip_xp_backends("torch", reasons=["dtypes do not match"])
@@ -949,7 +952,8 @@ def gen_oa_shapes_eq(sizes):
 
 
 
-@skip_xp_backends("jax.numpy", reasons=["fails all around"])
+@skip_xp_backends("jax.numpy", reasons=["fails all around"],
+                  cpu_only=True)
 class TestOAConvolve:
     @pytest.mark.slow()
     @pytest.mark.parametrize('shape_a_0, shape_b_0',
@@ -1184,6 +1188,7 @@ class TestAllFreqConvolves:
         assert res.dtype == dtype
 
 
+@skip_xp_backends(cpu_only=True)
 class TestMedFilt:
 
     IN = [[50, 50, 50, 50, 50, 92, 18, 27, 65, 46],
@@ -1330,7 +1335,7 @@ class TestMedFilt:
 
 class TestWiener:
 
-    @skip_xp_backends("cupy", reasons=["XXX: can_cast in cupy <= 13.2"])
+    @skip_xp_backends("cupy", reasons=["XXX: can_cast in cupy <= 13.2"], cpu_only=True)
     def test_basic(self, xp):
         g = xp.asarray([[5, 6, 4, 3],
                         [3, 5, 6, 2],
@@ -2249,7 +2254,8 @@ class _TestCorrelateReal:
                                 "uint64", "int64",
                                 "float32", "float64",
                                ])
-@skip_xp_backends("jax.numpy", reasons=["fails all around"])
+@skip_xp_backends("jax.numpy", reasons=["fails all around"],
+                  cpu_only=True)
 class TestCorrelateReal(_TestCorrelateReal):
     pass
 
@@ -2297,7 +2303,7 @@ class TestCorrelate:
         assert_raises(ValueError, correlate, [1], [[2]])
         assert_raises(ValueError, correlate, [3], 2)
 
-    @skip_xp_backends("jax.numpy", reasons=["dtype differs"])
+    @skip_xp_backends("jax.numpy", reasons=["dtype differs"], cpu_only=True)
     def test_numpy_fastpath(self, xp):
         a = xp.asarray([1, 2, 3])
         b = xp.asarray([4, 5])
@@ -2349,6 +2355,7 @@ def test_correlation_lags(mode, behind, input_size, xp):
     assert lags.shape == correlation.shape
 
 
+@skip_xp_backends(cpu_only=True)
 @pytest.mark.parametrize('dt_name', ['complex64', 'complex128'])
 class TestCorrelateComplex:
     # The decimal precision to be used for comparing results.
@@ -2471,6 +2478,7 @@ class TestCorrelateComplex:
                         check_shape=False)
 
 
+@skip_xp_backends(cpu_only=True)
 class TestCorrelate2d:
 
     def test_consistency_correlate_funcs(self, xp):
@@ -3885,6 +3893,7 @@ class TestDeconvolve:
             quotient, remainder = signal.deconvolve(recorded, impulse_response)
 
 
+@skip_xp_backends(cpu_only=True)
 class TestDetrend:
 
     def test_basic(self, xp):
ev-br ev-br force pushed from 6e89e1c4 to 3a4642b6 235 days ago
ev-br
ev-br ev-br force pushed from 3a4642b6 to 9c261243 235 days ago
lucascolley
lucascolley commented on 2024-09-27
lucascolley234 days ago

I resolved all open comments that looked resolved to me and spotted one problem.

If you would rather not have small reviews like this and would prefer only a full review at a time, let me know. I may not have time for that soon, though.

ev-br ev-br force pushed from 9c261243 to 3442c4f7 234 days ago
ev-br ev-br force pushed from 3442c4f7 to 706a0377 234 days ago
ev-br
ev-br ev-br force pushed from 706a0377 to 22a55e95 199 days ago
ev-br ev-br force pushed from 22a55e95 to 1d17a321 199 days ago
ev-br ev-br force pushed from 1d17a321 to 6d825ee9 197 days ago
ev-br ev-br force pushed from 6d825ee9 to 21619768 165 days ago
j-bowhay
j-bowhay requested changes on 2024-12-08
j-bowhay162 days ago

I had a look over this, spotted a few things, and probably missed a bunch more! Maybe we should just merge this since since there's a decent period before the next release to iron out any issues

lucascolley
lucascolley
lucascolley commented on 2024-12-08
lucascolley162 days ago

approving the production code changes! (modulo a few observations)

I'll take a look at the test changes once the existing comments have been addressed

lucascolley lucascolley added this to the 1.16.0 milestone 162 days ago
lucascolley lucascolley added enhancement
ev-br ENH: signal: add the infra for the array api in scipy.signal
301adbb1
ev-br MAINT: special: small simplify in _support_alternative_backends
3a4d5624
ev-br ENH: signal: array API support in convolution/correlation-like functions
b6ee2b16
ev-br TST: signal: adapt tests to array API
4fe49201
ev-br MAINT: signal: update for array_api_extra
4ced543a
ev-br MAINT: signal: linter-spotted fixes
64725e07
ev-br MAINT: signal: mark new envelope test as np-only
4b2b7156
ev-br ev-br force pushed from 159a28a3 to 41d4c3c7 162 days ago
ev-br MAINT: signal: address review comments
6d6543fe
ev-br ev-br force pushed from 41d4c3c7 to 6d6543fe 162 days ago
ev-br
j-bowhay
j-bowhay approved these changes on 2024-12-09
lucascolley
lucascolley approved these changes on 2024-12-09
lucascolley162 days ago

overall LGTM, thanks Evgeni & reviewers! Just a few comments

ev-br Update scipy/signal/tests/test_signaltools.py
6d7021ea
ev-br MAINT: signal: address review comments
3b3f9fcb
ev-br
ev-br ev-br merged 2c065e1f into main 161 days ago
j-bowhay
ev-br

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone