pytorch
bcaa93e8 - s390x simd: disable functions with out-of-bounds reads (#102266)

Commit
1 year ago
s390x simd: disable functions with out-of-bounds reads (#102266) 3 disabled functions are attempting out of bounds reads. Disable them until sleef library is fixed. <details> <summary>ASAN report</summary> ``` ================================================================= ==2030580==ERROR: AddressSanitizer: global-buffer-overflow on address 0x03ff70f54570 at pc 0x03ff6704e960 bp 0x03ffce128940 sp 0x03ffce128930 READ of size 4 at 0x03ff70f54570 thread T0 #0 0x3ff6704e95f in vgather_vf_p_vi2 /home/user/pytorch/third_party/sleef/src/arch/helpers390x_128.h:129 #1 0x3ff6704e95f in rempif /home/user/pytorch/third_party/sleef/src/libm/sleefsimdsp.c:550 #2 0x3ff6704e95f in Sleef_cosf4_u10vxe2 /home/user/pytorch/third_party/sleef/src/libm/sleefsimdsp.c:1021 #3 0x3ff67029cfb in Sleef_cosf4_u10 /home/user/pytorch/build/sleef/src/libm/disps390x_128.c:182 #4 0x3ff55d21941 in at::vec::ZVECTOR::Vectorized<float, void> at::vec::ZVECTOR::Vectorized<float, void>::mapSleef<float __vector(4) const (*)(float __vector(4)), double __vector(2) const (*)(double __ vector(2)), float, 0>(float __vector(4) const (*)(float __vector(4)), double __vector(2) const (*)(double __vector(2))) const /home/user/pytorch/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h:991 #5 0x3ff5689ad01 in at::vec::ZVECTOR::Vectorized<float, void>::cos() const /home/user/pytorch/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h:1074 #6 0x3ff5685df97 in at::vml::ZVECTOR::vcos<float>(float*, float const*, long)::{lambda(at::vec::ZVECTOR::Vectorized<float, void>)#1}::operator()(at::vec::ZVECTOR::Vectorized<float, void>) const /home/ user/pytorch/aten/src/ATen/cpu/vml.h:71 #7 0x3ff5689b691 in void at::vec::map<float, at::vml::ZVECTOR::vcos<float>(float*, float const*, long)::{lambda(at::vec::ZVECTOR::Vectorized<float, void>)#1}, 0>(at::vml::ZVECTOR::vcos<float>(float*, float const*, long)::{lambda(at::vec::ZVECTOR::Vectorized<float, void>)#1} const&, float*, float const*, long) /home/user/pytorch/aten/src/ATen/cpu/vec/functional_base.h:239 #8 0x3ff5685e0df in void at::vml::ZVECTOR::vcos<float>(float*, float const*, long) /home/user/pytorch/aten/src/ATen/cpu/vml.h:71 #9 0x3ff563fdde3 in operator() /home/user/pytorch/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp:770 #10 0x3ff5648e4a3 in operator() /home/user/pytorch/aten/src/ATen/TensorIterator.h:406 #11 0x3ff5663cae1 in callback_fn<at::TensorIteratorBase::loop_2d_from_1d<at::native::ZVECTOR::cos_kernel(at::TensorIteratorBase&)::<lambda()>::<lambda()>::<lambda(char**, const int64_t*, int64_t)> >(c onst at::native::ZVECTOR::cos_kernel(at::TensorIteratorBase&)::<lambda()>::<lambda()>::<lambda(char**, const int64_t*, int64_t)>&)::<lambda(char**, const int64_t*, int64_t, int64_t)> > /home/user/pytorch/ c10/util/FunctionRef.h:43 #12 0x3ff4d45a933 in c10::function_ref<void (char**, long const*, long, long)>::operator()(char**, long const*, long, long) const /home/user/pytorch/c10/util/FunctionRef.h:64 #13 0x3ff4d455133 in at::internal::serial_for_each(c10::ArrayRef<long>, c10::ArrayRef<long>, char**, unsigned long, c10::function_ref<void (char**, long const*, long, long)>, at::Range) /home/user/pyt orch/aten/src/ATen/TensorIteratorInternal.h:52 #14 0x3ff4d43b703 in at::TensorIteratorBase::serial_for_each(c10::function_ref<void (char**, long const*, long, long)>, at::Range) const /home/user/pytorch/aten/src/ATen/TensorIterator.cpp:777 #15 0x3ff4d43ab59 in at::TensorIteratorBase::for_each(c10::function_ref<void (char**, long const*, long, long)>, long) /home/user/pytorch/aten/src/ATen/TensorIterator.cpp:749 #16 0x3ff5648e851 in for_each<at::native::ZVECTOR::cos_kernel(at::TensorIteratorBase&)::<lambda()>::<lambda()>::<lambda(char**, const int64_t*, int64_t)> > /home/user/pytorch/aten/src/ATen/TensorItera tor.h:421 #17 0x3ff563fe5f9 in operator() /home/user/pytorch/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp:770 #18 0x3ff56400915 in operator() /home/user/pytorch/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp:770 #19 0x3ff56400f1d in at::native::ZVECTOR::cos_kernel(at::TensorIteratorBase&) /home/user/pytorch/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp:770 #20 0x3ff4f303007 in void at::native::DispatchStub<void (*)(at::TensorIteratorBase&), at::native::cos_stub>::operator()<at::native::structured_cos_out&>(c10::DeviceType, at::native::structured_cos_out &) /home/user/pytorch/aten/src/ATen/native/DispatchStub.h:158 #21 0x3ff4f2edb3f in at::native::structured_cos_out::impl(at::Tensor const&, at::Tensor const&) /home/user/pytorch/aten/src/ATen/native/UnaryOps.cpp:330 #22 0x3ff526ef739 in wrapper_CPU_cos /home/user/pytorch/build/aten/src/ATen/RegisterCPU.cpp:4307 #23 0x3ff52c651d9 in operator() /home/user/pytorch/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13 #24 0x3ff52c651d9 in call /home/user/pytorch/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:463 #25 0x3ff5076df2f in at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&>(void*, c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) /home/user/pytorch/aten/src/ATen/core /boxing/KernelFunction_impl.h:50 #26 0x3ff5009a93f in at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&>(c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&) const /home/user/pytorch/aten/src/ATen/core /boxing/KernelFunction_impl.h:103 #27 0x3ff5009a93f in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)> const&, at::Tensor const&) const /home/user/pytorch/aten/s rc/ATen/core/dispatch/Dispatcher.h:639 #28 0x3ff5009a93f in c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)>::call(at::Tensor const&) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:487 #29 0x3ff5009a93f in at::_ops::cos::call(at::Tensor const&) /home/user/pytorch/build/aten/src/ATen/Operators_0.cpp:2215 #30 0x3ff7d813741 in at::Tensor::cos() const /home/user/pytorch/build/aten/src/ATen/core/TensorBody.h:2107 #31 0x3ff7dc0f2b7 in operator() /home/user/pytorch/torch/csrc/autograd/generated/python_torch_functions_2.cpp:2953 #32 0x3ff7dc0faf7 in THPVariable_cos /home/user/pytorch/torch/csrc/autograd/generated/python_torch_functions_2.cpp:2955 #33 0x3ffa5ef5ae1 in cfunction_call Objects/methodobject.c:543 #34 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #35 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #36 0x3ffa5feb50d in do_call_core Python/ceval.c:5915 #37 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #38 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #39 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #40 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #41 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #42 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #43 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #44 0x3ff7f87a393 in torch::impl::dispatch::PythonKernelHolder::operator()(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) /home/user/pytorch/ torch/csrc/utils/python_dispatch.cpp:175 #45 0x3ff7f8871a7 in c10::BoxedKernel::makeFromFunctor<torch::impl::dispatch::PythonKernelHolder>(std::unique_ptr<torch::impl::dispatch::PythonKernelHolder, std::default_delete<torch::impl::dispatch:: PythonKernelHolder> >)::{lambda(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)#1}::operator()(c10::OperatorKernel*, c10::Op eratorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h:87 #46 0x3ff7f887261 in c10::BoxedKernel::makeFromFunctor<torch::impl::dispatch::PythonKernelHolder>(std::unique_ptr<torch::impl::dispatch::PythonKernelHolder, std::default_delete<torch::impl::dispatch:: PythonKernelHolder> >)::{lambda(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)#1}::_FUN(c10::OperatorKernel*, c10::Operator Handle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) /home/user/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h:86 #47 0x3ff7e0d10ab in c10::BoxedKernel::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/b oxing/BoxedKernel_impl.h:41 #48 0x3ff7e0d1459 in c10::KernelFunction::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/cor e/boxing/KernelFunction_impl.h:43 #49 0x3ff7f876421 in c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:6 91 #50 0x3ff4d22bcdd in c10::OperatorHandle::callBoxed(std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:417 #51 0x3ff65a092d5 in c10::OperatorHandle::callBoxed(std::vector<c10::IValue, std::allocator<c10::IValue> >&) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:421 #52 0x3ff65a05641 in operator() /home/user/pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp:15 #53 0x3ff65a08cb5 in __invoke_impl<void, torch::jit::(anonymous namespace)::createOperatorFromC10(const c10::OperatorHandle&)::<lambda(torch::jit::Stack&)>&, std::vector<c10::IValue, std::allocator<c1 0::IValue> >&> /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/invoke.h:61 #54 0x3ff65a0897b in __invoke_r<void, torch::jit::(anonymous namespace)::createOperatorFromC10(const c10::OperatorHandle&)::<lambda(torch::jit::Stack&)>&, std::vector<c10::IValue, std::allocator<c10:: IValue> >&> /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/invoke.h:111 #55 0x3ff65a084e1 in _M_invoke /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/std_function.h:290 #56 0x3ff7eb2cb21 in std::function<void (std::vector<c10::IValue, std::allocator<c10::IValue> >&)>::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >&) const /usr/lib/gcc/s390x-ibm-lin ux-gnu/11/include/g++-v11/bits/std_function.h:590 #57 0x3ff7eb1b659 in torch::jit::Operation::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >&) /home/user/pytorch/aten/src/ATen/core/stack.h:41 #58 0x3ff7eb08449 in torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args, pybind11:: kwargs const&, c10::optional<c10::DispatchKey>) /home/user/pytorch/torch/csrc/jit/python/pybind_utils.cpp:764 #59 0x3ff7eb09d85 in torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args, pybind11::kwargs const&, bool, c10::optional<c10::DispatchKey>) /home/user/pytorch/torch/csrc/jit/python/pybind_utils.cpp:829 #60 0x3ff7e573eb9 in operator() /home/user/pytorch/torch/csrc/jit/python/init.cpp:1549 #61 0x3ff7e6728dd in call_impl<pybind11::object, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&, 0, 1, pybind11::detail::vo id_type> /home/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1439 #62 0x3ff7e64312f in call<pybind11::object, pybind11::detail::void_type, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&> /h ome/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1408 #63 0x3ff7e5da259 in operator() /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:249 #64 0x3ff7e5da441 in _FUN /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:224 #65 0x3ff7d317a1f in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:929 #66 0x3ffa5ef5ae1 in cfunction_call Objects/methodobject.c:543 #67 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #68 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #69 0x3ffa5feb50d in do_call_core Python/ceval.c:5915 #70 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #71 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #72 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #73 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #74 0x3ffa5e83d1f in _PyObject_FastCallDictTstate Objects/call.c:142 #75 0x3ffa5e84937 in _PyObject_Call_Prepend Objects/call.c:431 #76 0x3ffa5f2f577 in slot_tp_call Objects/typeobject.c:7494 #77 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #78 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #79 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #80 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #81 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #82 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #83 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #84 0x3ffa5fd76a3 in _PyObject_VectorcallTstate Include/cpython/abstract.h:114 #85 0x3ffa5fd772f in PyObject_Vectorcall Include/cpython/abstract.h:123 #86 0x3ffa5feb289 in call_function Python/ceval.c:5891 #87 0x3ffa5fe5c3b in _PyEval_EvalFrameDefault Python/ceval.c:4213 #88 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #89 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #90 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #91 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #92 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #93 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #94 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #95 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #96 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #97 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #98 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #99 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #100 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #101 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #102 0x3ff7f87a393 in torch::impl::dispatch::PythonKernelHolder::operator()(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) /home/user/pytorch /torch/csrc/utils/python_dispatch.cpp:175 #103 0x3ff7f8871a7 in c10::BoxedKernel::makeFromFunctor<torch::impl::dispatch::PythonKernelHolder>(std::unique_ptr<torch::impl::dispatch::PythonKernelHolder, std::default_delete<torch::impl::dispatch: :PythonKernelHolder> >)::{lambda(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)#1}::operator()(c10::OperatorKernel*, c10::O peratorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h:87 #104 0x3ff7f887261 in c10::BoxedKernel::makeFromFunctor<torch::impl::dispatch::PythonKernelHolder>(std::unique_ptr<torch::impl::dispatch::PythonKernelHolder, std::default_delete<torch::impl::dispatch: :PythonKernelHolder> >)::{lambda(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)#1}::_FUN(c10::OperatorKernel*, c10::Operato rHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) /home/user/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h:86 #105 0x3ff7e0d10ab in c10::BoxedKernel::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/ boxing/BoxedKernel_impl.h:41 #106 0x3ff7e0d1459 in c10::KernelFunction::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/co re/boxing/KernelFunction_impl.h:43 #107 0x3ff7f876421 in c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h: 691 #108 0x3ff4d22bcdd in c10::OperatorHandle::callBoxed(std::vector<c10::IValue, std::allocator<c10::IValue> >*) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:417 #109 0x3ff65a092d5 in c10::OperatorHandle::callBoxed(std::vector<c10::IValue, std::allocator<c10::IValue> >&) const /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:421 #110 0x3ff65a05641 in operator() /home/user/pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp:15 #111 0x3ff65a08cb5 in __invoke_impl<void, torch::jit::(anonymous namespace)::createOperatorFromC10(const c10::OperatorHandle&)::<lambda(torch::jit::Stack&)>&, std::vector<c10::IValue, std::allocator<c 10::IValue> >&> /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/invoke.h:61 #112 0x3ff65a0897b in __invoke_r<void, torch::jit::(anonymous namespace)::createOperatorFromC10(const c10::OperatorHandle&)::<lambda(torch::jit::Stack&)>&, std::vector<c10::IValue, std::allocator<c10: :IValue> >&> /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/invoke.h:111 #113 0x3ff65a084e1 in _M_invoke /usr/lib/gcc/s390x-ibm-linux-gnu/11/include/g++-v11/bits/std_function.h:290 #114 0x3ff7eb2cb21 in std::function<void (std::vector<c10::IValue, std::allocator<c10::IValue> >&)>::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >&) const /usr/lib/gcc/s390x-ibm-li nux-gnu/11/include/g++-v11/bits/std_function.h:590 #115 0x3ff7eb1b659 in torch::jit::Operation::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >&) /home/user/pytorch/aten/src/ATen/core/stack.h:41 #116 0x3ff7eb08449 in torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args, pybind11: :kwargs const&, c10::optional<c10::DispatchKey>) /home/user/pytorch/torch/csrc/jit/python/pybind_utils.cpp:764 #117 0x3ff7eb09d85 in torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args, pybind11::kwargs const&, bool, c10::optional<c10::DispatchKey>) /home/user/pytorch/torch/csrc/jit/python/pybind_utils.cpp:829 #118 0x3ff7e573eb9 in operator() /home/user/pytorch/torch/csrc/jit/python/init.cpp:1549 #119 0x3ff7e6728dd in call_impl<pybind11::object, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&, 0, 1, pybind11::detail::v oid_type> /home/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1439 #120 0x3ff7e64312f in call<pybind11::object, pybind11::detail::void_type, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&> / home/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1408 #121 0x3ff7e5da259 in operator() /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:249 #122 0x3ff7e5da441 in _FUN /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:224 #123 0x3ff7d317a1f in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:929 #124 0x3ffa5ef5ae1 in cfunction_call Objects/methodobject.c:543 #125 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #126 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #127 0x3ffa5feb50d in do_call_core Python/ceval.c:5915 #128 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #129 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #130 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #131 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #132 0x3ffa5e83d1f in _PyObject_FastCallDictTstate Objects/call.c:142 #133 0x3ffa5e84937 in _PyObject_Call_Prepend Objects/call.c:431 #134 0x3ffa5f2f577 in slot_tp_call Objects/typeobject.c:7494 #135 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #136 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #137 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #138 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #139 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #140 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #141 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #142 0x3ffa5e87d2b in _PyObject_VectorcallTstate Include/cpython/abstract.h:114 #143 0x3ffa5e882dd in method_vectorcall Objects/classobject.c:83 #144 0x3ffa5e836d3 in _PyObject_VectorcallTstate Include/cpython/abstract.h:114 #145 0x3ffa5e84b6f in _PyObject_CallFunctionVa Objects/call.c:485 #146 0x3ffa5e84f2d in callmethod Objects/call.c:557 #147 0x3ffa5e85039 in PyObject_CallMethod Objects/call.c:577 #148 0x3ff7f7efa05 in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<pybind11::handle>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) /home/user/py torch/torch/csrc/utils/python_arg_parser.cpp:338 #149 0x3ff7eb09b67 in torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args, pybind11::kwargs const&, bool, c10::optional<c10::DispatchKey>) /home/user/pytorch/torch/csrc/jit/python/pybind_utils.cpp:827 #150 0x3ff7e573eb9 in operator() /home/user/pytorch/torch/csrc/jit/python/init.cpp:1549 #151 0x3ff7e6728dd in call_impl<pybind11::object, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&, 0, 1, pybind11::detail::v oid_type> /home/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1439 #152 0x3ff7e64312f in call<pybind11::object, pybind11::detail::void_type, torch::jit::initJITBindings(PyObject*)::<lambda(const string&, const string&)>::<lambda(pybind11::args, pybind11::kwargs)>&> / home/user/pytorch/third_party/pybind11/include/pybind11/cast.h:1408 #153 0x3ff7e5da259 in operator() /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:249 #154 0x3ff7e5da441 in _FUN /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:224 #155 0x3ff7d317a1f in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /home/user/pytorch/third_party/pybind11/include/pybind11/pybind11.h:929 #156 0x3ffa5ef5ae1 in cfunction_call Objects/methodobject.c:543 #157 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #158 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #159 0x3ffa5feb50d in do_call_core Python/ceval.c:5915 #160 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #161 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #162 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #163 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #164 0x3ffa5e83d1f in _PyObject_FastCallDictTstate Objects/call.c:142 #165 0x3ffa5e84937 in _PyObject_Call_Prepend Objects/call.c:431 #166 0x3ffa5f2f577 in slot_tp_call Objects/typeobject.c:7494 #167 0x3ffa5e84027 in _PyObject_MakeTpCall Objects/call.c:215 #168 0x3ffa5fd767b in _PyObject_VectorcallTstate Include/cpython/abstract.h:112 #169 0x3ffa5fd772f in PyObject_Vectorcall Include/cpython/abstract.h:123 #170 0x3ffa5feb289 in call_function Python/ceval.c:5891 #171 0x3ffa5fe5ad1 in _PyEval_EvalFrameDefault Python/ceval.c:4181 #172 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #173 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #174 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #175 0x3ffa5fd76a3 in _PyObject_VectorcallTstate Include/cpython/abstract.h:114 #176 0x3ffa5fd772f in PyObject_Vectorcall Include/cpython/abstract.h:123 #177 0x3ffa5feb289 in call_function Python/ceval.c:5891 #178 0x3ffa5fe5c3b in _PyEval_EvalFrameDefault Python/ceval.c:4213 #179 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #180 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #181 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #182 0x3ffa5e8427f in PyVectorcall_Call Objects/call.c:267 #183 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #184 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #185 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #186 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #187 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #188 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #189 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #190 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #191 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #192 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #193 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #194 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #195 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #196 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #197 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #198 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #199 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #200 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #201 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #202 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #203 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #204 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #205 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #206 0x3ffa5e841fb in PyVectorcall_Call Objects/call.c:255 #207 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #208 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #209 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #210 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #211 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #212 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #213 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #214 0x3ffa5e83d1f in _PyObject_FastCallDictTstate Objects/call.c:142 #215 0x3ffa5e84937 in _PyObject_Call_Prepend Objects/call.c:431 #216 0x3ffa5f2f577 in slot_tp_call Objects/typeobject.c:7494 #217 0x3ffa5e843f3 in _PyObject_Call Objects/call.c:305 #218 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #219 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #220 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #221 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #222 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #223 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #224 0x3ffa5fd76a3 in _PyObject_VectorcallTstate Include/cpython/abstract.h:114 #225 0x3ffa5fd772f in PyObject_Vectorcall Include/cpython/abstract.h:123 #226 0x3ffa5feb289 in call_function Python/ceval.c:5891 #227 0x3ffa5fe5b21 in _PyEval_EvalFrameDefault Python/ceval.c:4198 #228 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #229 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #230 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #231 0x3ffa5e8427f in PyVectorcall_Call Objects/call.c:267 #232 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #233 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #234 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #235 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #236 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #237 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #238 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #239 0x3ffa5e8427f in PyVectorcall_Call Objects/call.c:267 #240 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #241 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #242 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #243 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #244 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #245 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #246 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #247 0x3ffa5e8427f in PyVectorcall_Call Objects/call.c:267 #248 0x3ffa5e84347 in _PyObject_Call Objects/call.c:290 #249 0x3ffa5e84483 in PyObject_Call Objects/call.c:317 #250 0x3ffa5feb7cf in do_call_core Python/ceval.c:5943 #251 0x3ffa5fe6019 in _PyEval_EvalFrameDefault Python/ceval.c:4277 #252 0x3ffa5fd7aed in _PyEval_EvalFrame Include/internal/pycore_ceval.h:46 #253 0x3ffa5fe8ba9 in _PyEval_Vector Python/ceval.c:5065 #254 0x3ffa5e8459b in _PyFunction_Vectorcall Objects/call.c:342 #255 0x3ffa5e8427f in PyVectorcall_Call Objects/call.c:267 0x03ff70f54570 is located 0 bytes to the right of global variable 'Sleef_rempitabsp' defined in '/home/user/pytorch/third_party/sleef/src/libm/rempitab.c:986:34' (0x3ff70f53f00) of size 1648 SUMMARY: AddressSanitizer: global-buffer-overflow /home/user/pytorch/third_party/sleef/src/arch/helpers390x_128.h:129 in vgather_vf_p_vi2 Shadow bytes around the buggy address: 0x10007fee1ea850: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea860: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea870: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea880: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea890: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 =>0x10007fee1ea8a0: 00 00 00 00 00 00 00 00 00 00 00 00 00 00[f9]f9 0x10007fee1ea8b0: f9 f9 f9 f9 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea8c0: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea8d0: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea8e0: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 0x10007fee1ea8f0: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 Shadow byte legend (one shadow byte represents 8 application bytes): Addressable: 00 Partially addressable: 01 02 03 04 05 06 07 Heap left redzone: fa Freed heap region: fd Stack left redzone: f1 Stack mid redzone: f2 Stack right redzone: f3 Stack after return: f5 Stack use after scope: f8 Global redzone: f9 Global init order: f6 Poisoned by user: f7 Container overflow: fc Array cookie: ac Intra object redzone: bb ASan internal: fe Left alloca redzone: ca Right alloca redzone: cb Shadow gap: cc ==2030580==ABORTING ``` </details> It reproduces when running `pytest -v test/test_ops.py -k test_python_ref__refs_cos_cpu_bfloat16` under address sanitizer on s390x. See also: https://github.com/shibatch/sleef/issues/464 Pull Request resolved: https://github.com/pytorch/pytorch/pull/102266 Approved by: https://github.com/malfet
Committer
Parents
Loading