@bdhirsh Hi Brian, I'm having difficulties registering custom ops with functionalization enabled. Here is the error log; do you have any insights? Maybe the aten schema should look somewhat different?
root@t1v-n-f0938a8f-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo
/workspaces/work/transformers_pt/src/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
/home/ptxla/.local/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
warn(
test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor) ... WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707181259.319496 950358 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/ptxla/.local/lib/python3.8/site-packages/libtpu/libtpu.so
I0000 00:00:1707181259.319562 950358 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707181259.319572 950358 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
ERROR
======================================================================
ERROR: test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test/test_operations.py", line 1943, in test_tpu_custom_call_pallas_add_one_dynamo
compiled_add_one_pallas(output, [x], payload)
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "test/test_operations.py", line 1939, in add_one_pallas
def add_one_pallas(output, inputs, payload):
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_functorch/aot_autograd.py", line 903, in forward
return compiled_fn(full_args)
File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
return f(*args)
File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 95, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
return compiled_fw(args)
File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
return f(*args)
File "/workspaces/work/pytorch/torch/_dynamo/backends/torchxla.py", line 49, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 543, in extract_compiled_graph
collector.run(*xla_args)
File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 144, in run
self.env[node] = self.run_node(node)
File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
result = super().run_node(n)
File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 201, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 273, in call_function
return target(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_higher_order_ops/auto_functionalize.py", line 62, in __call__
return super().__call__(op, mutated_args_names, kwargs)
File "/workspaces/work/pytorch/torch/_ops.py", line 364, in __call__
return wrapper()
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_ops.py", line 360, in wrapper
return self.dispatch(
File "/workspaces/work/pytorch/torch/_ops.py", line 334, in dispatch
raise NotImplementedError(
NotImplementedError: could not find kernel for HigherOrderOperator auto_functionalized at dispatch key DispatchKey.Functionalize (resolved from DispatchKey.Functionalize)
While executing %auto_functionalized : [num_users=1] = call_function[target=torch._higher_order_ops.auto_functionalize.auto_functionalized](args = (xla.tpu_custom_call_.default, [output], {output: %arg0_1, inputs: [%arg1_1], payload: {"custom_call_config": {"body": "TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==", "needs_layout_passes": true}}}), kwargs = {})
Original traceback:
File "test/test_operations.py", line 1940, in add_one_pallas
torch.ops.xla.tpu_custom_call_(output, inputs, payload)
----------------------------------------------------------------------
Ran 1 test in 3.185s
FAILED (errors=1)
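For reference, here is a sketch of the kind of mutable schema involved, inferred from the call `torch.ops.xla.tpu_custom_call_(output, inputs, payload)` in the test (the exact argument types are assumptions on my part):

```cpp
#include <torch/library.h>

// Sketch of a mutable custom-op schema: the (a!) alias annotation marks
// `output` as mutated in place, and the op returns nothing.
TORCH_LIBRARY_FRAGMENT(xla, m) {
  m.def("tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()");
}
```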
@alanwaketan - you have a custom op that mutates some of its inputs, and recently @zou3519 added an "auto-functionalize" higher-order-op that tries to automatically functionalize mutable custom ops.
I'm not sure what's causing that error, although if you're worried about trace time, you might be a bit better off with a hand-written C++ functionalization kernel (similar to the code-generated ones we have for ATen).
You can find some examples to base it on if you build PyTorch locally and inspect some of the kernels in build/aten/src/ATen/RegisterFunctionalizeEverything.cpp
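For illustration, a hand-written Functionalize kernel for this op could look roughly like the sketch below, modeled on the structure of those code-generated kernels. The op signature is inferred from the call site in the test, and the exact helper names/overloads may need adjusting to your ATen version:

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Rough sketch of a hand-written functionalization kernel for
// xla::tpu_custom_call_. The signature (Tensor(a!) output, Tensor[] inputs,
// str payload) is an assumption based on the call site in the test.
static void tpu_custom_call__functionalize(
    at::Tensor& output,
    at::TensorList inputs,
    c10::string_view payload) {
  // Sync pending updates and unwrap the functional tensor wrappers.
  at::functionalization::impl::sync(output);
  at::functionalization::impl::sync(inputs);
  auto output_ = at::functionalization::impl::from_functional_tensor(output);
  auto inputs_ = at::functionalization::impl::from_functional_tensor(inputs);
  {
    // Re-dispatch below the Functionalize key so the real XLA kernel runs
    // on the unwrapped tensors.
    at::AutoDispatchSkipFunctionalize guard;
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("xla::tpu_custom_call_", "")
        .typed<void(at::Tensor&, at::TensorList, c10::string_view)>();
    op.call(output_, inputs_, payload);
  }
  // Propagate the mutation of `output_` back into the functional wrapper.
  at::functionalization::impl::replace_(output, output_);
  at::functionalization::impl::commit_update(output);
}

TORCH_LIBRARY_IMPL(xla, Functionalize, m) {
  m.impl("tpu_custom_call_", TORCH_FN(tpu_custom_call__functionalize));
}
```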
Thanks, Brian. Will look into this. On the other hand, I guess I can also change the semantics of my custom op to not be in-place. Then all the problems should go away?
I will land this as it is and do a follow-up to make the tpu_custom_call_ functional.
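One possible shape for that follow-up, sketched under the assumption that the output has the same shape and dtype as the first input (true for the add-one test; a general version would take explicit output shape/dtype arguments):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Hypothetical functional variant: it allocates and returns the output
// instead of mutating an argument, so there is nothing to functionalize.
// The existing in-place op is reused underneath.
static at::Tensor tpu_custom_call(at::TensorList inputs, c10::string_view payload) {
  // Assumption: the output matches the first input's shape/dtype, as in the
  // add-one example.
  at::Tensor output = at::empty_like(inputs[0]);
  static auto inplace_op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("xla::tpu_custom_call_", "")
      .typed<void(at::Tensor&, at::TensorList, c10::string_view)>();
  inplace_op.call(output, inputs, payload);
  return output;
}

TORCH_LIBRARY_FRAGMENT(xla, m) {
  m.def("tpu_custom_call(Tensor[] inputs, str payload) -> Tensor");
}

// Registering at the XLA key means functionalization simply passes through
// the non-mutating op; the in-place call happens only inside the backend kernel.
TORCH_LIBRARY_IMPL(xla, XLA, m) {
  m.impl("tpu_custom_call", TORCH_FN(tpu_custom_call));
}
```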
Thanks @qihqi for the approval.
Summary:
This pull request enables Dynamo support for custom TPU calls, e.g. ones written in Pallas.
Test Plan:
PJRT_DEVICE=TPU XLA_DISABLE_FUNCTIONALIZATION=1 python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo