Program Listing for File kernel.hpp

Return to documentation for file (/home/runner/work/Legion-Engine/Legion-Engine/legion/engine/core/compute/kernel.hpp)

#pragma once
#include "detail/cl_include.hpp"

#include <core/compute/buffer.hpp>
#include <core/logging/logging.hpp>
#include <functional>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include <Optick/optick.h>

namespace legion::core::compute
{
    /**
     * @brief Blocking selector accepted by the enqueue functions of Kernel.
     *
     * Stored in a bool: BLOCKING is false and NON_BLOCKING is true.
     * NOTE(review): presumably decides whether the enqueue call waits for the
     * transfer to complete - confirm against the implementation file.
     */
    enum class block_mode : bool
    {
        BLOCKING = false,
        NON_BLOCKING = true
    };

    // Forward declaration: Kernel only stores and accepts a Program pointer,
    // so the full definition of Program is not needed in this header.
    class Program;

    /**
     * @brief Reference-counted handle to an OpenCL kernel and the command queue it uses.
     *
     * Copies of a Kernel share one heap-allocated counter (m_refcounter). When the
     * last handle sharing that counter is destroyed or reassigned, the counter is
     * deleted and the command queue is released with clReleaseCommandQueue.
     * Nothing in this header releases m_func or m_prog.
     *
     * @note The counter is a plain size_t that is incremented and decremented
     *       without synchronization, so handles sharing a counter must not be
     *       copied, assigned or destroyed concurrently.
     */
    class Kernel
    {
    public:
        /// Pair of extents; one of the alternatives of `dimension`.
        struct d2
        {
            size_type s0;
            size_type s1;
        };

        /// Triple of extents; one of the alternatives of `dimension`.
        struct d3
        {
            size_type s0;
            size_type s1;
            size_type s3; // third extent; the name skips "s2" and is kept because it is part of the public interface
        };

        /// Global size in one (size_type), two (d2) or three (d3) dimensions.
        using dimension = std::variant<size_type,d2,d3>;

        /**
         * @brief Defined out of line.
         * NOTE(review): presumably allocates the reference counter and sets up the
         * command queue that the destructor releases - confirm in the implementation file.
         */
        Kernel(Program*, cl_kernel);

        /**
         * @brief Creates another handle that shares the counter and queue of @p other.
         */
        Kernel(const Kernel& other)
            : m_refcounter(other.m_refcounter),
              m_default_mode(other.m_default_mode),
              m_paramsMap(other.m_paramsMap),
              m_prog(other.m_prog),
              m_func(other.m_func),
              m_queue(other.m_queue),
              m_global_size(other.m_global_size),
              m_local_size(other.m_local_size)
        {
            retain();
        }

        /**
         * @brief Moves the parameter map and sizes out of @p other while sharing its queue.
         *
         * @p other keeps its own reference to the counter, so it stays safely
         * destructible (and usable) after the move.
         */
        Kernel(Kernel&& other) noexcept
            : m_refcounter(other.m_refcounter),
              m_default_mode(other.m_default_mode),
              m_paramsMap(std::move(other.m_paramsMap)),
              m_prog(other.m_prog),
              m_func(other.m_func),
              m_queue(other.m_queue),
              m_global_size(std::move(other.m_global_size)),
              m_local_size(std::move(other.m_local_size))
        {
            retain();
        }

        /**
         * @brief Makes this handle share the state of @p other.
         *
         * The reference this handle held before the assignment is dropped, which
         * releases the previous queue if this was its last handle.
         *
         * @return *this
         */
        Kernel& operator=(const Kernel& other)
        {
            OPTICK_EVENT();
            if (this == &other)
                return *this;

            // Take the new reference before dropping the old one, so a counter that
            // both handles already share can never reach zero in between.
            if (other.m_refcounter)
                ++*other.m_refcounter;
            release();

            m_refcounter = other.m_refcounter;
            m_default_mode = other.m_default_mode;
            m_paramsMap = other.m_paramsMap;
            m_prog = other.m_prog;
            m_func = other.m_func;
            m_queue = other.m_queue;
            m_global_size = other.m_global_size;
            m_local_size = other.m_local_size;
            return *this;
        }

        /**
         * @brief Move-assigns the parameter map and sizes while sharing the queue of @p other.
         *
         * As with the move constructor, @p other keeps its own reference. The
         * reference this handle held before the assignment is dropped.
         *
         * @return *this
         */
        Kernel& operator=(Kernel&& other) noexcept
        {
            OPTICK_EVENT();
            if (this == &other)
                return *this;

            // Same ordering as copy assignment: acquire first, then release.
            if (other.m_refcounter)
                ++*other.m_refcounter;
            release();

            m_refcounter = other.m_refcounter;
            m_default_mode = other.m_default_mode;
            m_paramsMap = std::move(other.m_paramsMap);
            m_prog = other.m_prog;
            m_func = other.m_func;
            m_queue = other.m_queue;
            m_global_size = std::move(other.m_global_size);
            m_local_size = std::move(other.m_local_size);
            return *this;
        }

        /**
         * @brief Drops this handle's reference; the last handle deletes the counter
         *        and releases the command queue.
         */
        ~Kernel()
        {
            OPTICK_EVENT();
            release();
        }

        // The member functions below are defined out of line. Their behaviour is not
        // visible in this header; the short notes are inferred from the signatures
        // only and should be confirmed against the implementation file.

        // NOTE(review): presumably sets the local size stored in m_local_size.
        Kernel& local(size_type);

        // NOTE(review): presumably set the global size stored in m_global_size
        // (1D/2D/3D through the variant, or 2D / 3D through the explicit overloads).
        Kernel& global(dimension);
        Kernel& global(size_type,size_type);
        Kernel& global(size_type,size_type,size_type);

        // Called lazily by param_find when m_paramsMap is empty.
        // NOTE(review): presumably fills m_paramsMap with parameter name -> index.
        Kernel& buildBufferNames();

        // NOTE(review): presumably sets m_default_mode.
        Kernel& readWriteMode(buffer_type);

        Kernel& enqueueBuffer(Buffer buffer, block_mode blocking = block_mode::BLOCKING);

        Kernel& setBuffer(Buffer buffer);

        Kernel& setBuffer(Buffer buffer, const std::string& name);

        Kernel& setBuffer(Buffer buffer, cl_uint index);


        /**
         * @brief Typed convenience overload: forwards to the untyped overload with
         *        sizeof(T) as the argument size.
         * @param value Pointer to the argument value.
         * @param name  Kernel parameter name.
         */
        template <class T>
        Kernel& setKernelArg(T* value, const std::string& name)
        {
            OPTICK_EVENT();
            return setKernelArg(value,sizeof(T),name);
        }

        /**
         * @brief Typed convenience overload: forwards to the untyped overload with
         *        sizeof(T) as the argument size.
         * @param value Pointer to the argument value.
         * @param index Kernel parameter index.
         */
        template <class T>
        Kernel& setKernelArg(T* value, cl_uint index)
        {
            OPTICK_EVENT();
            return setKernelArg(value,sizeof(T),index);
        }


        Kernel& setKernelArg(void* value, size_type size, const std::string& name);

        Kernel& setKernelArg(void* value, size_type size, cl_uint index);

        Kernel& setAndEnqueueBuffer(Buffer buffer, block_mode blocking = block_mode::BLOCKING);

        Kernel& setAndEnqueueBuffer(Buffer buffer, const std::string&, block_mode blocking = block_mode::BLOCKING);

        Kernel& setAndEnqueueBuffer(Buffer buffer, cl_uint index, block_mode blocking = block_mode::BLOCKING);

        Kernel& dispatch();

        void finish() const;


        size_type getMaxWorkSize() const;

    private:

        // Shared reference count; may be null, in which case the handle owns nothing.
        size_t* m_refcounter;

        buffer_type m_default_mode;
        // Kernel parameter name -> index, consulted by param_find.
        std::map<std::string, cl_uint> m_paramsMap;
        Program* m_prog;
        cl_kernel m_func;
        // Released with clReleaseCommandQueue when the last sharing handle goes away.
        cl_command_queue m_queue;

        /// Adds one reference to the shared counter, if there is one.
        void retain() noexcept
        {
            if (m_refcounter)
                ++*m_refcounter;
        }

        /// Drops this handle's reference; the last owner deletes the counter and
        /// releases the command queue. Leaves m_refcounter null.
        void release() noexcept
        {
            if (!m_refcounter)
                return;

            if (--*m_refcounter == 0)
            {
                delete m_refcounter;
                clReleaseCommandQueue(m_queue);
            }
            m_refcounter = nullptr;
        }

        /**
         * @brief Expands the stored sizes into per-dimension vectors.
         *
         * @return A tuple of (global extents, local sizes, number of dimensions).
         *         The local vector has one entry per dimension, each equal to
         *         m_local_size / dim (integer division).
         * @throws std::bad_variant_access if m_global_size holds no value.
         */
        std::tuple<std::vector<size_type>,std::vector<size_type>,size_type> parse_dimensions()
        {
            OPTICK_EVENT();

            // Extents are copied member by member instead of reinterpreting the
            // struct as a size_type array, which relied on pointer arithmetic
            // across distinct members.
            std::vector<size_type> global;
            if (const d3* g3 = std::get_if<d3>(&m_global_size))
            {
                global = { g3->s0, g3->s1, g3->s3 };
            }
            else if (const d2* g2 = std::get_if<d2>(&m_global_size))
            {
                global = { g2->s0, g2->s1 };
            }
            else
            {
                global = { std::get<size_type>(m_global_size) };
            }

            const auto dim = static_cast<size_type>(global.size());
            std::vector<size_type> local(global.size(), m_local_size / dim);

            return std::make_tuple(std::move(global), std::move(local), dim);
        }

        dimension m_global_size;
        size_type m_local_size;


        /**
         * @brief Looks up a kernel parameter by name and invokes @p func with its index.
         *
         * The name map is built on first use (whenever it is empty). If @p name is
         * not a known parameter, a warning is logged and @p func is not called.
         *
         * @tparam F    Callable invocable with a cl_uint index.
         * @tparam Args Unused; kept so existing explicit template arguments still compile.
         * @param func  Callable to invoke with the parameter index.
         * @param name  Kernel parameter name to look up.
         */
        template <class F,class... Args>
        void param_find(F && func,const std::string& name)
        {
            OPTICK_EVENT();
            // lazy build buffer names
            if (m_paramsMap.empty())
            {
                buildBufferNames();
            }

            // unknown names are reported, not treated as an error
            const auto it = m_paramsMap.find(name);
            if (it == m_paramsMap.end())
            {
                log::warn("Encountered buffer with name: \"{}\", which was not found as a kernel parameter", name);
                return;
            }

            std::invoke(std::forward<F>(func),it->second);
        }
    };
}