Go 博客

Go 并发模式:管道与取消

Sameer Ajmani
2014 年 3 月 13 日

引言

Go 的并发原语使得构建流式数据管道变得容易,这些管道可以有效地利用 I/O 和多个 CPU。本文提供了此类管道的示例,强调了操作失败时出现的细微之处,并介绍了优雅处理故障的技术。

什么是管道?

在 Go 中没有管道的正式定义;它只是众多并发程序类型中的一种。非正式地讲,管道是由通道连接的一系列阶段,每个阶段是一组运行相同函数的 goroutine。在每个阶段中,goroutine 会

  • 通过入站通道从上游接收值
  • 对数据执行某些函数,通常会产生新值
  • 通过出站通道将值发送到下游

每个阶段都有任意数量的入站和出站通道,但第一阶段和最后阶段除外,它们分别只有出站或入站通道。第一阶段有时称为生产者;最后阶段称为消费者

我们将从一个简单的管道示例开始,以解释这些思想和技术。稍后,我们将介绍一个更实际的示例。

数字平方

考虑一个有三个阶段的管道。

第一阶段,gen,是一个将整数列表转换为发射列表中整数的通道的函数。gen 函数启动一个 goroutine,将整数发送到通道,并在所有值发送完毕后关闭通道

func gen(nums ...int) <-chan int {
    out := make(chan int)
    go func() {
        for _, n := range nums {
            out <- n
        }
        close(out)
    }()
    return out
}

第二阶段,sq,从通道接收整数并返回一个通道,该通道发射接收到的每个整数的平方。在入站通道关闭且此阶段将所有值发送到下游后,它会关闭出站通道

func sq(in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        for n := range in {
            out <- n * n
        }
        close(out)
    }()
    return out
}

main 函数设置管道并运行最后阶段:它从第二阶段接收值并打印每一个值,直到通道关闭

func main() {
    // Set up the pipeline.
    c := gen(2, 3)
    out := sq(c)

    // Consume the output.
    fmt.Println(<-out) // 4
    fmt.Println(<-out) // 9
}

由于 sq 的入站和出站通道类型相同,我们可以多次组合它。我们也可以将 main 重写为范围循环,就像其他阶段一样

func main() {
    // Set up the pipeline and consume the output.
    for n := range sq(sq(gen(2, 3))) {
        fmt.Println(n) // 16 then 81
    }
}

扇出、扇入

多个函数可以从同一个通道读取,直到该通道关闭;这称为扇出。这提供了一种将工作分配给一组工作者以并行化 CPU 使用和 I/O 的方法。

一个函数可以通过将输入通道复用到一个单一通道(该通道在所有输入通道都关闭时关闭),从多个输入读取并继续执行直到所有输入都关闭。这称为扇入

我们可以修改我们的管道以运行两个 sq 实例,每个实例从同一个输入通道读取。我们引入一个新函数 merge 来扇入结果

func main() {
    in := gen(2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(in)
    c2 := sq(in)

    // Consume the merged output from c1 and c2.
    for n := range merge(c1, c2) {
        fmt.Println(n) // 4 then 9, or 9 then 4
    }
}

merge 函数通过为每个入站通道启动一个 goroutine 将值复制到唯一的出站通道,从而将通道列表转换为单个通道。一旦所有 output goroutine 都已启动,merge 会再启动一个 goroutine,以便在该通道上的所有发送操作完成后关闭出站通道。

在已关闭的通道上发送会引发 panic,因此在调用 close 之前确保所有发送操作完成非常重要。sync.WaitGroup 类型提供了一种安排此同步的简单方法

func merge(cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // Start an output goroutine for each input channel in cs.  output
    // copies values from c to out until c is closed, then calls wg.Done.
    output := func(c <-chan int) {
        for n := range c {
            out <- n
        }
        wg.Done()
    }
    wg.Add(len(cs))
    for _, c := range cs {
        go output(c)
    }

    // Start a goroutine to close out once all the output goroutines are
    // done.  This must start after the wg.Add call.
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

中途停止

我们的管道函数有一个模式

  • 阶段在所有发送操作完成后关闭其出站通道。
  • 阶段不断从入站通道接收值,直到这些通道关闭。

这种模式允许每个接收阶段写成一个 range 循环,并确保一旦所有值都已成功发送到下游,所有 goroutine 都会退出。

但在实际的管道中,阶段并不总是接收所有入站值。有时这是故意设计的:接收方可能只需要值的子集就能继续进行。更常见的情况是,阶段提前退出,因为入站值表示先前阶段发生了错误。无论哪种情况,接收方都不必等待剩余的值到达,并且我们希望较早的阶段停止产生后续阶段不需要的值。

在我们的示例管道中,如果某个阶段未能消费所有入站值,则尝试发送这些值的 goroutine 将无限期阻塞

    // Consume the first value from the output.
    out := merge(c1, c2)
    fmt.Println(<-out) // 4 or 9
    return
    // Since we didn't receive the second value from out,
    // one of the output goroutines is hung attempting to send it.
}

这会导致资源泄露:goroutine 消耗内存和运行时资源,goroutine 堆栈中的堆引用会阻止数据被垃圾回收。goroutine 本身不会被垃圾回收;它们必须自行退出。

我们需要安排管道的上游阶段即使在下游阶段未能接收所有入站值时也能退出。一种方法是将出站通道更改为带缓冲区。缓冲区可以容纳固定数量的值;如果缓冲区中有空间,发送操作会立即完成

c := make(chan int, 2) // buffer size 2
c <- 1  // succeeds immediately
c <- 2  // succeeds immediately
c <- 3  // blocks until another goroutine does <-c and receives 1

在创建通道时已知要发送的值的数量时,缓冲区可以简化代码。例如,我们可以重写 gen 将整数列表复制到带缓冲的通道中,并避免创建新的 goroutine

func gen(nums ...int) <-chan int {
    out := make(chan int, len(nums))
    for _, n := range nums {
        out <- n
    }
    close(out)
    return out
}

回到我们管道中被阻塞的 goroutine,我们可能会考虑为 merge 返回的出站通道添加一个缓冲区

func merge(cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int, 1) // enough space for the unread inputs
    // ... the rest is unchanged ...

虽然这解决了此程序中被阻塞的 goroutine,但这代码不好。这里选择缓冲区大小为 1 取决于知道 merge 将接收多少值以及下游阶段将消费多少值。这是脆弱的:如果我们向 gen 传递额外的值,或者如果下游阶段读取的值减少了,我们又会遇到阻塞的 goroutine。

相反,我们需要提供一种方法,让下游阶段能够向上游发送方指示它们将停止接受输入。

显式取消

main 决定在未接收到 out 中的所有值时退出,它必须告诉上游阶段的 goroutine 放弃它们尝试发送的值。它通过向一个名为 done 的通道发送值来实现。它发送两个值,因为潜在有两个被阻塞的发送方

func main() {
    in := gen(2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(in)
    c2 := sq(in)

    // Consume the first value from output.
    done := make(chan struct{}, 2)
    out := merge(done, c1, c2)
    fmt.Println(<-out) // 4 or 9

    // Tell the remaining senders we're leaving.
    done <- struct{}{}
    done <- struct{}{}
}

发送 goroutine 将其发送操作替换为 select 语句,该语句在向 out 发送发生时或当它们从 done 接收到值时继续执行。done 的值类型为空结构体,因为值不重要:是接收事件指示应放弃向 out 的发送。output goroutine 继续在其入站通道 c 上循环,因此上游阶段不会被阻塞。(我们稍后会讨论如何允许此循环提前返回。)

func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // Start an output goroutine for each input channel in cs.  output
    // copies values from c to out until c is closed or it receives a value
    // from done, then output calls wg.Done.
    output := func(c <-chan int) {
        for n := range c {
            select {
            case out <- n:
            case <-done:
            }
        }
        wg.Done()
    }
    // ... the rest is unchanged ...

这种方法存在一个问题:每个下游接收方需要知道潜在阻塞的上游发送方的数量,并安排在提前返回时向这些发送方发出信号。跟踪这些计数是乏味且容易出错的。

我们需要一种方法来告诉未知且数量不限的 goroutine 停止向下游发送它们的值。在 Go 中,我们可以通过关闭通道来做到这一点,因为 在已关闭通道上的接收操作总是能立即进行,产生元素类型的零值。

这意味着 main 只需关闭 done 通道即可解除所有发送方的阻塞。这种关闭实际上是对发送方的广播信号。我们将管道函数每个都扩展为接受 done 作为参数,并通过 defer 语句安排关闭操作,以便 main 的所有返回路径都将信号通知管道阶段退出。

func main() {
    // Set up a done channel that's shared by the whole pipeline,
    // and close that channel when this pipeline exits, as a signal
    // for all the goroutines we started to exit.
    done := make(chan struct{})
    defer close(done)          

    in := gen(done, 2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(done, in)
    c2 := sq(done, in)

    // Consume the first value from output.
    out := merge(done, c1, c2)
    fmt.Println(<-out) // 4 or 9

    // done will be closed by the deferred call.      
}

现在,我们管道的每个阶段都可以立即返回,只要 done 关闭。merge 中的 output 例程可以返回而无需耗尽其入站通道,因为它知道上游发送方 sqdone 关闭时将停止尝试发送。output 通过 defer 语句确保在所有返回路径上调用 wg.Done

func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // Start an output goroutine for each input channel in cs.  output
    // copies values from c to out until c or done is closed, then calls
    // wg.Done.
    output := func(c <-chan int) {
        defer wg.Done()
        for n := range c {
            select {
            case out <- n:
            case <-done:
                return
            }
        }
    }
    // ... the rest is unchanged ...

类似地,sq 可以在 done 关闭后立即返回。sq 通过 defer 语句确保在其所有返回路径上关闭其 out 通道

func sq(done <-chan struct{}, in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for n := range in {
            select {
            case out <- n * n:
            case <-done:
                return
            }
        }
    }()
    return out
}

以下是管道构建的指导方针

  • 阶段在所有发送操作完成后关闭其出站通道。
  • 阶段持续从入站通道接收值,直到这些通道关闭或发送方解除阻塞。

管道通过确保有足够的缓冲区容纳所有发送的值,或者在接收方可能放弃通道时明确通知发送方来解除发送方的阻塞。

遍历并摘要文件树

让我们考虑一个更实际的管道。

MD5 是一种消息摘要算法,常用作文件校验和。命令行工具 md5sum 可以打印一系列文件的摘要值。

% md5sum *.go
d47c2bbc28298ca9befdfbc5d3aa4e65  bounded.go
ee869afd31f83cbb2d10ee81b2b831dc  parallel.go
b88175e65fdcbc01ac08aaf1fd9b5e96  serial.go

我们的示例程序类似于 md5sum,但它接受单个目录作为参数,并按路径名排序后打印该目录下每个普通文件的摘要值。

% go run serial.go .
d47c2bbc28298ca9befdfbc5d3aa4e65  bounded.go
ee869afd31f83cbb2d10ee81b2b831dc  parallel.go
b88175e65fdcbc01ac08aaf1fd9b5e96  serial.go

我们的程序的主函数调用一个辅助函数 MD5All,该函数返回一个从路径名到摘要值的映射,然后对结果进行排序并打印

func main() {
    // Calculate the MD5 sum of all files under the specified directory,
    // then print the results sorted by path name.
    m, err := MD5All(os.Args[1])
    if err != nil {
        fmt.Println(err)
        return
    }
    var paths []string
    for path := range m {
        paths = append(paths, path)
    }
    sort.Strings(paths)
    for _, path := range paths {
        fmt.Printf("%x  %s\n", m[path], path)
    }
}

MD5All 函数是我们讨论的重点。在 serial.go 中,实现没有使用并发,它只是在遍历文件树时简单地读取和计算每个文件的摘要。

// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents.  If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(root string) (map[string][md5.Size]byte, error) {
    m := make(map[string][md5.Size]byte)
    err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
        if err != nil {
            return err
        }
        if !info.Mode().IsRegular() {
            return nil
        }
        data, err := ioutil.ReadFile(path)
        if err != nil {
            return err
        }
        m[path] = md5.Sum(data)
        return nil
    })
    if err != nil {
        return nil, err
    }
    return m, nil
}

并行摘要计算

parallel.go 中,我们将 MD5All 分成一个两阶段的管道。第一阶段,sumFiles,遍历文件树,在新的 goroutine 中对每个文件计算摘要,并将结果发送到值类型为 result 的通道上

type result struct {
    path string
    sum  [md5.Size]byte
    err  error
}

sumFiles 返回两个通道:一个用于 results,另一个用于 filepath.Walk 返回的错误。遍历函数为处理每个普通文件启动一个新的 goroutine,然后检查 done。如果 done 关闭,遍历会立即停止

func sumFiles(done <-chan struct{}, root string) (<-chan result, <-chan error) {
    // For each regular file, start a goroutine that sums the file and sends
    // the result on c.  Send the result of the walk on errc.
    c := make(chan result)
    errc := make(chan error, 1)
    go func() {
        var wg sync.WaitGroup
        err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            wg.Add(1)
            go func() {
                data, err := ioutil.ReadFile(path)
                select {
                case c <- result{path, md5.Sum(data), err}:
                case <-done:
                }
                wg.Done()
            }()
            // Abort the walk if done is closed.
            select {
            case <-done:
                return errors.New("walk canceled")
            default:
                return nil
            }
        })
        // Walk has returned, so all calls to wg.Add are done.  Start a
        // goroutine to close c once all the sends are done.
        go func() {
            wg.Wait()
            close(c)
        }()
        // No select needed here, since errc is buffered.
        errc <- err
    }()
    return c, errc
}

MD5Allc 接收摘要值。MD5All 在发生错误时提前返回,通过 defer 关闭 done

func MD5All(root string) (map[string][md5.Size]byte, error) {
    // MD5All closes the done channel when it returns; it may do so before
    // receiving all the values from c and errc.
    done := make(chan struct{})
    defer close(done)          

    c, errc := sumFiles(done, root)

    m := make(map[string][md5.Size]byte)
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil
}

受限的并行性

parallel.go 中的 MD5All 实现为每个文件启动一个新的 goroutine。在一个包含许多大文件的目录中,这可能会分配比机器可用内存更多的内存。

我们可以通过限制并行读取的文件数量来限制这些分配。在 bounded.go 中,我们通过创建固定数量的 goroutine 来读取文件。我们的管道现在有三个阶段:遍历文件树、读取并计算文件摘要、收集摘要。

第一阶段,walkFiles,发射文件树中普通文件的路径

func walkFiles(done <-chan struct{}, root string) (<-chan string, <-chan error) {
    paths := make(chan string)
    errc := make(chan error, 1)
    go func() {
        // Close the paths channel after Walk returns.
        defer close(paths)
        // No select needed for this send, since errc is buffered.
        errc <- filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            select {
            case paths <- path:
            case <-done:
                return errors.New("walk canceled")
            }
            return nil
        })
    }()
    return paths, errc
}

中间阶段启动固定数量的 digester goroutine,它们从 paths 接收文件名,并通过通道 c 发送 results

func digester(done <-chan struct{}, paths <-chan string, c chan<- result) {
    for path := range paths {
        data, err := ioutil.ReadFile(path)
        select {
        case c <- result{path, md5.Sum(data), err}:
        case <-done:
            return
        }
    }
}

与我们之前的示例不同,digester 不会关闭其输出通道,因为有多个 goroutine 在共享通道上发送。相反,MD5All 中的代码会在所有 digesters 完成后安排关闭通道

    // Start a fixed number of goroutines to read and digest files.
    c := make(chan result)
    var wg sync.WaitGroup
    const numDigesters = 20
    wg.Add(numDigesters)
    for i := 0; i < numDigesters; i++ {
        go func() {
            digester(done, paths, c)
            wg.Done()
        }()
    }
    go func() {
        wg.Wait()
        close(c)
    }()

我们也可以让每个 digester 创建并返回其自己的输出通道,但那样我们就需要额外的 goroutine 来扇入结果。

最后阶段从 c 接收所有 results,然后检查来自 errc 的错误。此检查不能提前进行,因为在此之前,walkFiles 可能会因向下游发送值而阻塞

    m := make(map[string][md5.Size]byte)
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    // Check whether the Walk failed.
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil
}

结论

本文介绍了在 Go 中构建流式数据管道的技术。处理此类管道中的故障很棘手,因为管道中的每个阶段都可能因尝试向下游发送值而阻塞,并且下游阶段可能不再关心传入的数据。我们展示了如何通过关闭通道向管道启动的所有 goroutine 广播“完成”信号,并定义了正确构建管道的指导方针。

延伸阅读

下一篇文章:Go 地鼠
上一篇文章:FOSDEM 2014 上的 Go 演讲
博客索引