xla
[Pallas] Support Dynamo #6477
Merged

alanwaketan merged 3 commits into master from alanwaketan/pallas_dynamo
alanwaketan · 1 year ago · 👍 2

Summary:
This pull request enables Dynamo support for custom TPU calls, e.g. ones written in Pallas.

Test Plan:
PJRT_DEVICE=TPU XLA_DISABLE_FUNCTIONALIZATION=1 python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo
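
Example (sketch): the new test exercises roughly the pattern below, reconstructed from the traceback later in this thread. The backend name, dtype/shape, and the elided payload are assumptions rather than the exact test code.

```python
import torch
import torch_xla.core.xla_model as xm

def add_one_pallas(output, inputs, payload):
  # In-place custom call: the Pallas kernel serialized in `payload` writes
  # its result into `output`.
  torch.ops.xla.tpu_custom_call_(output, inputs, payload)

device = xm.xla_device()
x = torch.arange(8, dtype=torch.int32, device=device)  # shape/dtype assumed
output = torch.empty_like(x)
payload = "..."  # serialized Pallas kernel (custom_call_config), elided here

# Compile the wrapper with Dynamo; "openxla" as the backend name is an
# assumption about how the test selects the torch_xla backend.
compiled_add_one_pallas = torch.compile(add_one_pallas, backend="openxla")
compiled_add_one_pallas(output, [x], payload)
```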

alanwaketan requested a review from lsy323 1 year ago
alanwaketan requested a review from wonjoolee95 1 year ago
alanwaketan self-assigned this 1 year ago
alanwaketan · 1 year ago

@bdhirsh Hi Brian, I'm having difficulties registering custom ops with functionalization enabled. Here is the error log; do you have any insights? Maybe the aten schema should look different?

root@t1v-n-f0938a8f-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call_pallas_add_one_dynamo
/workspaces/work/transformers_pt/src/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/ptxla/.local/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor) ... WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707181259.319496  950358 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/ptxla/.local/lib/python3.8/site-packages/libtpu/libtpu.so
I0000 00:00:1707181259.319562  950358 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707181259.319572  950358 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
ERROR

======================================================================
ERROR: test_tpu_custom_call_pallas_add_one_dynamo (__main__.TestAtenXlaTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_operations.py", line 1943, in test_tpu_custom_call_pallas_add_one_dynamo
    compiled_add_one_pallas(output, [x], payload)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "test/test_operations.py", line 1939, in add_one_pallas
    def add_one_pallas(output, inputs, payload):
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_functorch/aot_autograd.py", line 903, in forward
    return compiled_fn(full_args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 95, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/workspaces/work/pytorch/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/workspaces/work/pytorch/torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 543, in extract_compiled_graph
    collector.run(*xla_args)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 144, in run
    self.env[node] = self.run_node(node)
  File "/workspaces/work/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 431, in run_node
    result = super().run_node(n)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 201, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspaces/work/pytorch/torch/fx/interpreter.py", line 273, in call_function
    return target(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_higher_order_ops/auto_functionalize.py", line 62, in __call__
    return super().__call__(op, mutated_args_names, kwargs)
  File "/workspaces/work/pytorch/torch/_ops.py", line 364, in __call__
    return wrapper()
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_ops.py", line 360, in wrapper
    return self.dispatch(
  File "/workspaces/work/pytorch/torch/_ops.py", line 334, in dispatch
    raise NotImplementedError(
NotImplementedError: could not find kernel for HigherOrderOperator auto_functionalized at dispatch key DispatchKey.Functionalize (resolved from DispatchKey.Functionalize)

While executing %auto_functionalized : [num_users=1] = call_function[target=torch._higher_order_ops.auto_functionalize.auto_functionalized](args = (xla.tpu_custom_call_.default, [output], {output: %arg0_1, inputs: [%arg1_1], payload: {"custom_call_config": {"body": "TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==", "needs_layout_passes": true}}}), kwargs = {})
Original traceback:
  File "test/test_operations.py", line 1940, in add_one_pallas
    torch.ops.xla.tpu_custom_call_(output, inputs, payload)


----------------------------------------------------------------------
Ran 1 test in 3.185s

FAILED (errors=1)
bdhirsh · 1 year ago · 👍 1

@alanwaketan - you have a custom op that mutates some of its inputs, and recently @zou3519 added an "auto-functionalize" higher-order-op that tries to automatically functionalize mutable custom ops.

I'm not sure what's causing that error. Although if you're worried about trace-time, you might be a bit better off with a hand-written C++ functionalization kernel (similar to the code-generated ones we have for ATen).

You can find some examples to base it off of if you build pytorch locally, and inspect some of the kernels in build/aten/src/ATen/RegisterFunctionalizeEverything.cpp
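
A hand-written functionalization kernel, stripped to its essence, calls a pure variant of the op and then copies the result back into the mutated argument. A minimal Python illustration of that rewrite (not the actual C++ kernel, and `add_one_` here is a stand-in rather than `torch.ops.xla.tpu_custom_call_`):

```python
import torch

def add_one_(x: torch.Tensor) -> torch.Tensor:
  # Mutating form: writes into x, which is what functionalization has to
  # handle (and what fails above when no Functionalize kernel exists).
  x.add_(1)
  return x

def add_one_functionalized(x: torch.Tensor) -> torch.Tensor:
  # What a functionalization kernel does conceptually: run a pure variant
  # of the computation, then copy the result back into the original buffer
  # so callers still observe the mutation.
  out = x.add(1)  # pure computation
  x.copy_(out)    # re-apply the mutation to the input
  return x
```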

alanwaketan · 1 year ago

> @alanwaketan - you have a custom op that mutates some of its inputs, and recently @zou3519 added an "auto-functionalize" higher-order-op that tries to automatically functionalize mutable custom ops.
>
> I'm not sure what's causing that error. Although if you're worried about trace-time, you might be a bit better off with a hand-written C++ functionalization kernel (similar to the code-generated ones we have for ATen).
>
> You can find some examples to base it off of if you build pytorch locally, and inspect some of the kernels in build/aten/src/ATen/RegisterFunctionalizeEverything.cpp

Thanks, Brian. Will look into this. On the other hand, I guess I can also change the semantics of my custom op so it is not in-place. Then all the problems should go away?
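
A functional variant could look roughly like the sketch below: a schema that returns a fresh tensor instead of mutating `output`. Everything here (op name, schema, and the Meta shape rule) is a hypothetical illustration, not the signature torch_xla actually registers:

```python
import torch
from torch.library import Library

# Hypothetical functional schema: return the result rather than writing
# into a preallocated `output` tensor.
lib = Library("xla", "FRAGMENT")  # extend the "xla" namespace
lib.define("tpu_custom_call(Tensor[] inputs, str payload) -> Tensor")

def tpu_custom_call_meta(inputs, payload):
  # Meta/fake-tensor rule for tracing; assumes the output matches inputs[0].
  return torch.empty_like(inputs[0])

lib.impl("tpu_custom_call", tpu_custom_call_meta, "Meta")
```

With no mutation to functionalize, the auto_functionalized wrapper that fails above should no longer be needed.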

qihqi approved these changes on 2024-02-08
alanwaketan: Initial commit (a935c3ee)
alanwaketan: Add a mock env (282ca10f)
alanwaketan: Fix linters (ebccfaaf)
alanwaketan force-pushed from bc6b14c9 to ebccfaaf 1 year ago
alanwaketan · 1 year ago

I will land this as it is and do a follow-up to make tpu_custom_call_ functional.

Thanks @qihqi for the approval.

alanwaketan merged ce8ee38e into master 1 year ago
