当前位置: 首页 > 工具软件 > nnvm > 使用案例 >

深度学习编译中间件之NNVM(十七)NNVM源代码阅读6

鲍驰
2023-12-01

参考文档

  1. 深度学习编译中间件之NNVM(十二)NNVM源代码阅读1
  2. 深度学习编译中间件之NNVM(十三)NNVM源代码阅读2
  3. 深度学习编译中间件之NNVM(十四)NNVM源代码阅读3
  4. 深度学习编译中间件之NNVM(十五)NNVM源代码阅读4
  5. 深度学习编译中间件之NNVM(十六)NNVM源代码阅读5

这篇文档将讲解和HalideIR相关的内容。

HalideIR是一个创建符号表达式和算术简化的基础模块。它从原始的Halide项目重构而来,用于TVM项目中。

工程结构

HalideIR的代码基于Halide(release_2017_05_03),由四个部分组成:

  • tvm:TVM封装器代码,用于基础数据结构
  • base:基础类型和工具
  • ir:IR数据结构
  • arithmetic:算术简化

base部分

base部分具体提供了哪些功能:

  • 代码生成调试(可选)
  • 编译时错误与异常处理
  • float16实现(去除LLVM依赖)
  • halide基础类型定义
  • halide实用工具函数定义

接下来着重讲解halide基础类型和实用工具这两个内容。

base部分将一系列类型表示为C++函数签名,这种形式拥有两个优点:

  1. 可以为Halide函数提供正确的原型,提供更好的编译时类型校验
  2. C++命名编码能为Halide函数和外部调用函数提供链接时类型校验

base部分还提供了一些实用函数

  • extract_namespaces
  • add_would_overflow:加法数值溢出判断
  • sub_would_overflow:减法数值溢出判断
  • mul_would_overflow:乘法数值溢出判断

tvm部分

tvm部分比较重要的数据结构有:

  • tvm::Node
  • tvm::NodeRef
  • tvm::ArrayNode(在DSL计算图中使用)
  • tvm::MapNode(在DSL计算图中使用)
  • tvm::IRFunctor
// 代码节选
class EXPORT Node : public std::enable_shared_from_this<Node> {
public:
    virtual const char* type_key() const = 0;
    virtual void VisitAttrs(AttrVisitor* visitor) {}
}

/*! NOdeRef是所有节点引用对象的基类 */
class NodeRef {
    using ContainerType = Node;

    inline bool operator==(const NodeRef& other) const;
    inline bool same_as(const NodeRef& other) const;

    inline bool operator<(const NodeRef& other) const;
    inline bool operator!=(const NodeRef& other) const;

    inline uint32_t type_index() const;
    inline const Node* operator->() const;

    template<typename T>
    inline const T *as() const;

    NodeRef() = default;
    explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {}
    std::shared_ptr<Node> node_;
}

ir部分

tvm/HalideIR/ir/Expr.h

/** 一个处理statement node的引用计数的handle  */
struct Stmt : public IRHandle {
    Stmt() : IRHandle() {}
    Stmt(std::shared_ptr<IR::Node> n) : IRHandle(n) {}

    /** Dispatch to the correct visitor method for this node. E.g. if
     * this node is actually an Add node, then this will call
     * IRVisitor::visit(const Add *) */
    inline void accept(Internal::IRVisitor *v) const {
        static_cast<const Internal::BaseStmtNode *>(node_.get())->accept(v, *this);
    }
    /*! \brief type indicate the container type */
    using ContainerType = Internal::BaseStmtNode;
};

IR.h里面存放了深度学习需要的IR基础节点

  • 固定值(IntImm/UIntImm/FloatImm/StringImm)
  • 二进制算数运算(Add/Sub/Mul/Div/Mod/Min/Max)
  • 比较运算符()
  • 逻辑运算符(And/Or/Not)
  • 选择运算符(Select)
  • *

tvm/HalideIR/ir/IR.cpp

arithmetic部分

 类似资料: