pytorch
3139e687 - [Vulkan] Add Basic Shader Registry (#91914)

Commit
1 year ago
[Vulkan] Add Basic Shader Registry (#91914) @bypass-github-export-checks We want to be able to look-up which shader to use in a registry given a particular op/algorithm name, which is what this diff enables. This is done with the newly added ```shader_registry``` map and ```look_up_shader_info``` function. After this change, Shaders can be retrieved either with the ```VK_KERNEL``` macro, which gets the Shader with a specified name directly, or with the ```VK_REGISTRY_KERNEL``` macro, which looks up what Shader should be used for a specified algorithm name in the registry. For now, the registry is empty and unused. In the next diffs in this stack, I will be adding support for registering a shader in the registry in GLSL and GLSLT + Params Yaml files. I also - Adjusted the formatting of spv.h and spv.cpp so that they are closer to what clang wants, which makes them easier to read. (proper indentation, proper order of includes, etc.) - Moved the codegen spv/registry code from at::native::vulkan to at::native::vulkan::api (since registry.cpp / .h are in ```ATen/native/vulkan/api```) Now spv.h looks like ``` #pragma once #include <ATen/native/vulkan/api/Types.h> #include <ATen/native/vulkan/api/vk_api.h> #include <c10/util/flat_hash_map.h> #include <string> namespace at { namespace native { namespace vulkan { namespace api { struct ShaderInfo; } // namespace api typedef ska::flat_hash_map<std::string, api::ShaderInfo> ShaderListing; typedef ska::flat_hash_map<std::string, std::string> RegistryKeyMap; typedef ska::flat_hash_map<std::string, RegistryKeyMap> ShaderRegistry; extern const ShaderListing shader_infos; extern ShaderRegistry shader_registry; inline const ShaderListing& get_shader_infos() { return shader_infos; } inline ShaderRegistry& get_shader_registry() { return shader_registry; } } // namespace vulkan } // namespace native } // namespace at ``` and spv.cpp looks like ``` #include <ATen/native/vulkan/api/Shader.h> #include <ATen/native/vulkan/spv.h> #include <stdint.h> #include <vector> namespace at { namespace native { namespace vulkan { namespace { const uint32_t adaptive_avg_pool2d_bin[] = { 119734787, ... }; ... const uint32_t conv2d_pw_2x2_bin[] = { 119734787, ... }; } // namespace const ShaderListing shader_infos = { {"adaptive_avg_pool2d", api::ShaderInfo( "vulkan.adaptive_avg_pool2d", adaptive_avg_pool2d_bin, 3204, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, std::vector<uint32_t>(), api::StorageType::UNKNOWN, api::StorageType::UNKNOWN)}, ... {"conv2d_pw_2x2", api::ShaderInfo( "vulkan.conv2d_pw_2x2", conv2d_pw_2x2_bin, 7736, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, {2, 2, 1}, api::StorageType::TEXTURE_2D, api::StorageType::TEXTURE_2D)}}; ShaderRegistry shader_registry = { }; } // namespace vulkan } // namespace native } // namespace at ``` (Full File: P594112814) Differential Revision: [D41594453](https://our.internmc.facebook.com/intern/diff/D41594453/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/91914 Approved by: https://github.com/mcr229
Author
Committer
Parents
Loading