当前位置: 首页 > 工具软件 > Relay > 使用案例 >

【TVM帮助文档学习】Relay的代数数据类型

范玄裳
2023-12-01

本文翻译自Algebraic Data Types in Relay — tvm 0.9.dev0 documentation
代数数据类型(ADT)是函数式编程语言的主要特征,尤其是那些派生自ML的语言,因为它们表示数据结构的方式在编写递归计算时很容易推理。因为递归是Relay中控制流的主要机制之一,所以为了最优地表达循环和其他必须使用递归实现的控制流结构,引入ADT是非常重要。

ADT定义和匹配

注意:目前文本格式不支持ADT。这里的语法是基于其他语言中的ADT推测的。
ADT可以理解为类C语言中枚举和结构类型的通用版本。与C struct类似:ADT实例是指定类型字段的容器,但类型系统允许同一类型以系统方式对不同的字段可能分组进行编码,就像C的enum类型是由用户指定的有限可能值集定义的
具体来说,ADT被定义为一组命名构造函数,每个构造函数都是一个函数,它接受指定类型的值作为参数,并返回一个命名的ADT实例。ADT实例只包含传递给生成自己的构造函数的参数值
ADT值在析构之前是不透明的,允许构造函数的参数再次被访问并用于计算新值。因为一个特定的ADT可以有多个具有不同签名的构造函数,所以通常需要在不同的可能构造函数上分支,从而产生ADT的匹配语法。因此,ADT有时被称为“带标记的联合”,因为ADT实例是由用于生成它的构造函数的名称标记的,以后可以基于标记进行检查。
因为每个ADT都有一个有限的构造函数集,所以很容易确定处理ADT实例的函数是否处理所有可能的情况。特别是,与c语言中的联合类型相比,类型系统可以确保在解析ADT实例时,在所有情况下都正确地分配类型。因此,对ADT进行推理通常很容易。
实现细节:Relay的ADT定义是全局的,存储在模块中,类似于全局函数定义。ADT的名称实际上是一个全局类型变量(就像全局函数名称是一个全局变量一样)。模块持有ADT名称(全局类型变量)到该ADT的构造函数列表的映射。
下面是一个定义ADT的简单例子,并通过匹配表达式在函数中使用它:

# Defines an ADT named "Numbers"
data Numbers {
  Empty : () -> Numbers
  Single : (Tensor[(), int32]) -> Numbers
  Pair : (Tensor[(), int32], Tensor[(), int32]) -> Numbers
}
# A Numbers value can be produced using an Empty, Single, or Pair
# constructor, each with a signature given above

def @sum(%n : Numbers[]) -> Tensor[(), int32] {
   # The match expression branches on the constructor that was
   # used to produce %n. The variables in each case are bound
   # if the constructor matches that used for %n
   match(%n) {
     case Empty() { 0 }
     case Single(x) { x }
     case Pair(x, y) { x + y }
   }
}

@sum(Empty())    # evaluates to 0
@sum(Single(3))  # evaluates to 3
@sum(Pair(5, 6)) # evaluates to 11

注意,ADT是通过名称标识的,这意味着两个具有相同构造函数的ADT从类型检查器的角度来看仍然是不同的数据类型。

# structurally identical constructors to Numbers
data Numbers2 {
  Empty2 : () -> Numbers2
  Single2 : (Tensor[(), int32]) -> Numbers2
  Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2
}

# the below results in a type error because Numbers2
# is a distinct type from Numbers
# fn() { @sum(Empty2()) }

ADT类型检查和多态

本节将更详细地介绍ADT的类型。与函数一样,ADT可以是多态的,并且可以接受类型参数,这是ADT最复杂的特性。
例如,函数式编程语言中通用标准ADT之一是可选类型,定义如下:

# a is a type parameter
data Optional<a> {
  None : () -> Optional
  Some : (a) -> Optional
}

可选类型通常用作任何涉及查询数据结构的操作的返回类型(如果找到值则返回Some(v),如果没有则返回None)。在定义中接受类型参数允许在各种情况下使用相同的可选类型,而不必为每种可能类型都定义一个独有的ADT。
然而,必须确保类型系统可以区分不同的选项的类型,因为如果一个函数期望一个包含Tensor[(),int32]的选项,而不是一个Tensor[(3,4),float32]的选项,就会违反类型安全。正如这个例子所暗示的那样,一个ADT实例因此被赋予了一个类型,该类型包含该实例的具体类型参数,以确保信息被保留。用下面的例子来说明:

# the signature for option indicates the type argument
def @inc_scalar(%opt : Optional[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%opt) {
    case None() { 1 }
    case Some(%s) { %s + 1 }
  }
}

def @main() {
  let %one : Optional[Tensor[(), int32]] = Some(1);
  let %big : Optional[Tensor[(10, 10), float32]]
    = Some(Constant(1, (10, 10), float32));
  let %two = inc_scalar(%one);
  # let %bigger = inc_scalar(%big); # type system rejects
  # None does not take an argument so it can always implicitly
  # be given the correct type arguments
  let %z = inc_scalar(None());
  ()
}

上面例子中带类型参数说明的语法(例如,Optional[(), int32])被称为“类型调用”,将多态的ADT定义视为类型级函数(接受类型参数并返回类型,即ADT)。任何出现在类型说明或函数签名中的ADT都必须用类型参数进行说明(非多态ADT必须在没有参数的类型调用中)。
因此,一般来说,如果接收的类型形数为T1,…, Tn的构造函数C是ADT D的构造函数,它接受类型实参为v1,…, vn(其中T1,…, Tn可以包含任何v1,…, vn),那么C的类型是fun<v1, ... , vn>(T1, ... , Tn) -> D[v1, ... , vn]。这意味着构造函数的类型与普通函数类似,因此出现在调用节点中,可以传递给其他函数或由其他函数返回。特别地,上面的一些示例具有签名 fun(a) -> Optional[a],而None具有签名 signature fun() -> Optional[a]。

ADT递归

ADT定义允许递归,也就是说,名为D的ADT定义可以假定类型D存在,并将其用作构造函数的参数。递归允许adt表示复杂的结构,如列表或树;它是函数式编程中adt强大功能的来源,因为一个设计得当的数据结构可以使用递归函数精确地表示计算变得容易。
许多常用的adt都涉及递归;其中一些是在Common ADT Uses中给出的。下面的例子中,我们将检查ADT列表,它在函数式语言中无处不在:

data List<a> {
   Nil : () -> List
   Cons : (a, List[a]) -> List
}

(注意,即使是在构造函数中,对List的递归引用也是在类型调用中包装的)
上面的定义意味着特定类型的值的列表可以用嵌套的Cons构造函数表示,一直到列表的末尾,而末尾可以用Nil(表示空列表)表示。
以这种方式表示的列表可以很容易地递归处理。例如,下面的函数对一个整数列表求和:

def @list_sum(%l : List[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%l) {
    case Nil() { 0 }
    # add the head of the list to the sum of the tail
    case Cons(%h, %t) { %h + @list_sum(%t) }
  }
}

恰巧许多处理列表的递归函数就像刚才给出的共享结构一样,可以分解为通用的、易于使用的函数,这些函数将在 Common ADT Uses中讨论。

表达式匹配中的模式匹配

与其他函数式语言一样,Relay中的匹配表达式能够进行更通用的模式匹配,而不是被解析的值的数据类型的每个构造函数只涵盖一种情况。
特别是,匹配项中的模式可以递归地构建:
构造函数模式匹配特定的ADT构造函数。如果一个值匹配构造函数,构造函数的每个参数都将根据嵌套模式进行匹配。
通配符模式将匹配任何值,而且不会绑定到变量。
变量模式将匹配任何值,并将其绑定到一个局部变量,其作用域为match子句。
在上面@list_sum的简单例子中,第一个匹配例子有一个Nil构造函数模式(没有嵌套参数),第二个匹配例子有一个Cons构造函数模式,它为Cons的每个参数使用变量模式。
下面的例子使用了通配符模式来忽略Cons的一个参数:

def @first<a>(%l : List[a]) -> Optional[a] {
  match(%l) {
    case Nil() { None() }
    case Cons(%h, _) { Some(%h) } # list tail is unused and ignored
  }
}

这里,构造函数模式嵌套在另一个构造函数模式中,以避免嵌套列表选项的匹配表达式。顶层通配符模式也用于处理所有不匹配第一个子句的情况:

def @second_opt<a>(%ll : Optional[List[a]]) -> Optional[a] {
  match(%ll) {
    # we only need the second member of the list if there is one
    case Some(Cons(_, Cons(%s, _))) { Some(%s) }
    case _ { None() }
  }
}

# @second_opt(Some(Cons(1, Nil()))) evaluates to None()
# @second_opt(Some(Cons(1, Cons(2, Nil())))) evaluates to Some(2)
# @second_opt(Some(Nil())) evaluates to None()
# @second_opt(None()) evaluates to None()

注意匹配表达式按照case条件顺序检查其模式(即从上到下检查满足哪个case):第一个和输入匹配的case字句将会被执行。在这里,顶层变量模式绑定了整个输入值:

def @match_order_beware<a>(%l : List[a]) -> List[a] {
  match(%l) {
    case %v { %v }
    # the above matches everything so neither of these runs
    case Cons(%h, %t) { Cons(%h, @match_order_beware(%t)) }
    case Nil() { Nil() }
  }
}

Common ADT Uses

在函数式编程语言中,某些ADT为编写公共程序提供了有用的工具。参数多态性和高阶函数允许这些ADT易于重用,并允许泛型函数在常见情况下对它们进行操作。Relay包括某些预定义ADT的“前奏”,以及其对应于其他语言中不可或缺的ADT的函数
在类型检查ADT和多态性时定义的选项类型就是这样的一个ADT,每当函数只能在特定情况下返回值时,就会使用它。拥有选项类型允许类型系统跟踪哪些函数总是返回某一类型的值,而不是返回该类型的选项,确保始终显式检查任何选项(另一种解决问题的方法是返回空指针或抛出异常)。
列表(在adt递归中定义的)可以由泛型函数操作,操作方式类似于Python中的列表推导式和某些库函数。下面是列表迭代的常用函数,它们包含在Relay的Prelude中。(这些都在函数式编程文献中得到了广泛的描述,我们不打算在本文档中赘述这些工作。)

# Map: for [h1, h2, ..., hn] returns [f(h1), f(h2), ..., f(hn)]
def @map<a, b>(%f : fn(a) -> b, %l : List[a]) -> List[b] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%h, %t) { Cons(%f(%h), @map(%f, %t)) }
  }
}

# Left fold: for [h1, h2, ..., hn] returns f(...(f(f(z, h1), h2)...), hn)
def @foldl<a, b>(%f : fn(b, a) -> b, %z : b, %l : List[a]) -> b {
  match(%l) {
    case Nil() { %z }
    case Cons(%h, %t) { @foldl(%f, %f(%z, %h), %t) }
  }
}

# Right fold: for [h1, h2, ..., hn] returns f(h1, f(h2, f(..., (f(hn, z)...)
def @foldr<a, b>(%f : fn(a, b) -> b, %z : b, %l : List[a] -> b {
  match(%l) {
    case Nil() { %z }
    case Cons(%h, %t) { %f(%h, @foldr(%f, %z, %t)) }
  }
}

使用这些迭代结构,列表上的许多常见操作可以更紧凑地表示。例如,下面的映射对列表中的所有成员进行加倍:

# directly written
def @double(%l : List[Tensor[(), int32]]) -> List[Tensor[(), int32]] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%h, %t) { Cons(%h * 2, @double(%t)) }
  }
}

# map takes care of the recursion
@map(fn(%i) { %i * 2 }, %l)

下面的右折叠连接两个列表:

# directly written
def @concat<a>(%l1 : List[a], %l2 : List[a]) -> List[a] {
  match(%l1) {
    case Nil() { %l2 }
    case Cons(%h, %t) { Cons(%h, @concat(%t, %l2) }
  }
}

# foldr takes care of the recursion
@foldr(fn(%h, %z) { Cons(%h, %z) }, %l2, %l1)

下面的左折叠平坦化列表中的子列表(使用连接):

# directly written
def @flatten<a>(%ll : List[List[a]]) -> List[a] {
  match(%ll) {
    case Cons(%h, %t) { @concat(%h, @flatten(%t)) }
    case Nil() { Nil() }
  }

# foldl takes care of the recursion
@foldl(@concat, Nil(), %ll)

这些迭代构造可以直接使用Relay基本语句实现,而且更容易定义(对于更多的数据类型,比如树),无需使用内置的复杂结构(例如MXNet中的“foreach”)。ADT及其可扩展性允许在Relay中表达广泛的迭代和数据结构,而且可以直接被类型系统支持,无需修改语言实现。

使用ADT实现神经网络

在这篇2015年的博客文章(http://colah.github.io/posts/2015-09-NN-Types-FP/)中,Christopher Olah指出,许多神经网络可以很容易地用常用的函数编程结构来表达。使用Relay的adt可以直接在TVM中实现博文中示例。
首先,假设我们有一个训练好的循环神经网络(RNN)单元对应的函数,它接受的输入有过去的状态和输入值,返回新的状态和输出值。在Relay中,这将包含以下签名:

@cell : fn<state_type, in_type, out_type>(state_type, in_type) -> (state_type, out_type)

我们以ReLU单元为例,训练后的版本如下:

def @linear(%x, %w, %b) { %w*%x + %b }

def @relu_cell(%w, # weights
               %b, # offsets
               %s, # state
               %x  # input
) {
  let %x2 = @linear(%x, %w.0, %b.0);
  let %s2 = @linear(%s, %w.1, %b.1);
  # doesn't change the state
  (%s, nn.relu(%x2 + %s2))
}

# this is a higher-order function because it returns a closure
def @trained_cell(%w, %b) {
  fn(%x, %h) { @relu_cell(%w, %b, %x, %h) }
}

按照Olah的例子,我们可以用下面的左折叠编码一个输入序列(列表):

def @encode<state_type, in_type, out_type>(%cell, %input : List[in_type], %init : state_type) -> state_type {
  # not using the output
  @foldl(fn(%state, %in) { %cell(%state, %in).0 }, %init, %input)
}

使用一个展开迭代器(来自Haskell的标准库),相同的单元可以用来创建一个生成器网络(它接受一个输入并产生一个输出序列):

# included in Relay's Prelude
def @unfoldr<a, b>(%f : fn(b) -> Optional[(a, b)], %z : b) -> List[a] {
  match(%f(%z)) {
    case Some(%pair) { Cons(%pair.0, @unfoldr(%f, %pair.1)) }
    case None() { Nil() }
  }
}

# we need some way of generating an input to the cell function given only a state
def @gen_func<state_type, in_type, out_type>(%state : state_type) : Optional[(out_type, state_type)] {
  let %in : Optional[in_type] = @generate_input(%state);
  match(%in) {
    case Some(%n) {
      let %cell_out = @cell(%n, %state);
      Some((%cell_out.1, %cell_out.0)) # pair of output and state
    }
    case None() { None() }
  }
}

def @generator<state_type, in_type, out_type>(%cell, %init : state_type) -> List[out_type] {
  @unfoldr(fn(%state) { @gen_func(%cell, %state) }, %init)
}

一个累加映射(同时更新累加器值和输出列表的折叠)可以用来编写一个通用的RNN(每个输入都有一个输出):

def @map_accumr<a, b, c>(%f : fn(a, b) -> (a, c), %acc : a, %l : List[b]) -> (a, List[c]) {
  match(%l) {
    case Nil() { (%acc, Nil()) }
    case Cons(%b, %t) {
      let %update = %f(%acc, %b);
      let %rest = @map_accumr(%f, %update.0, %t));
      (%rest.0, Cons(%update.1, %rest.1))
    }
  }
}

# can also be implemented as a right fold
# (this version is included in Relay's Prelude)
def @map_accumr_fold(%f, %acc, %l) {
  @foldr(fn(%b, %p) {
    let %f_out = %f(%p.0, %b);
    (%f_out.0, Cons(%f_out.1, %p.1))
  },
  (%acc, Nil()), %l)
}

def @general_rnn<state_type, in_type, out_type>(%cell, %init : state_type, %input : List[in_type])
  -> (state_type, List[out_type]) {
  @map_accumr(%cell, %init, %input)
}

Olah还给出了一个双向神经网络的例子,其中两组细胞(可能有不同的权重)在两个方向上处理输入,并产生一组输出。下面是该示例的一个Relay实现:

# creates a list of tuples from two lists
# included in Relay's Prelude
def @zip<a, b>(%l : List[a], %m : List[b]) -> List[(a, b)] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%a, %t1) {
      match(%m) {
        case Nil() { Nil() }
        case Cons(%b, %t2) { Cons((%a, %b), @zip(%t1, %t2)) }
      }
    }
  }
}

# analogous to map_accumr
# included in Relay's Prelude
def @map_accmul(%f, %acc, %l) {
  @foldl(fn(%p, %b){
    let %f_out = %f(%p.0, %b);
    (%f_out.0, Cons(%f_out.1, %p.1))
  }, (%acc, Nil()), %l)
}

def @bidirectional_rnn<state1_type, state2_type, in_type, out1_type, out2_type>
  (%cell1, %cell2, %state1 : state1_type, %state2 : state2_type, %input : List[in_type])
  -> List[(out1_type, out2_type)] {
  @zip(@map_accumr(%cell1, %state1, %input).1, @map_accuml(%cell2, %state2, %input).1)
}

 类似资料: