Go Errorgroup


原文链接: Go Errorgroup

https://github.com/bketelsen/gogrep
Run strikingly fast parallel file searches in Go with sync.ErrGroup - O'Reilly Media

【翻译】使用 sync.ErrGroup 实现并发搜索文件 - Go中国技术社区 - golang
Go 的一个很重要的的特性就是其原生的并发,像 channel 和 goroutines 这样的利器。但是对于一个新手来说 goroutines 这个概念可能比较陌生。

Go 团队发布的第一个 goroutines 的管理工具是 sync.WaitGroup,这个工具允许你创建 WaitGroup 去等待一定数量的 goroutines 执行完成。这里有个例子:

var wg sync.WaitGroup
var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/",
}
for _, url := range urls {
        // Increment the WaitGroup counter.
        wg.Add(1)
        // Launch a goroutine to fetch the URL.
        go func(url string) {
                // Decrement the counter when the goroutine completes.
                defer wg.Done()
                // Fetch the URL.
                http.Get(url)
        }(url)
}
// Wait for all HTTP fetches to complete.
wg.Wait()

WaitGroup 使你在处理并发任务时对 goroutines 的创建和停止的数量控制都变的更加简单。每次你创建 goroutine 的时候只要调用 Add() 就可以了。当这个任务结束调用 wg.Done()。等待所有的任务完成,调用 wg.Wait()。但是用 WatiGroup 唯一的问题就是当你的 goroutines 出错时,你不能捕获到错误的原因。

sync.WaitGroup 的加强版

最近 Go 的团队在试验的代码仓库中增加了一个包叫 sync.ErrGroup。sync.ErrGroup 相当于为 sync.WaitGroup 增加了错误返回的功能。下面我们来看一个同样的例子:

var g errgroup.Group
var urls = []string{
    "",
    "",
    "",
}
for _, url := range urls {
    url := url
    g.Go(func() error {
        resp, err := http.Get(url)
        if err == nil {
            resp.Body.Close()
        }
        return err
    })
}
if err := g.Wait(); err == nil {
    fmt.Println("Successfully fetched all URLs.")
}

g.Go() 这个方法不仅允许你传一个匿名的函数,而且还能捕获错误信息,你只要像这样返回一个错误 return err。这对使用 goroutines 的开发者来说在功能上是一个很大的提升。

为了测试 sync.ErrGroup 的所有特性, 我已经写了一个递归搜索指定目录的Go程序。并且我还加了一个超时的机制 。当程序超时了,所有的 goroutines 将被取消,程序退出。

当执行程序的结果如下:

$ gogrep -timeout 1000ms . fmt
gogrep.go
1 hits

如果你没有使用正确的参数,将会输出正确的使用方法:

gogrep by Brian Ketelsen
Flags:
-timeout duration

    timeout in milliseconds (default 500ms)

Usage:

gogrep [flags] path pattern

sync.ErrGroup 如何是我们程序写的更加简单呢

让我们看看我们是如何利用 sync.ErrGroup 来使程序写的更简单。我们将从一个 main() 函数开始,因为我喜欢像写故事一样写代码,每个代码的故事都是从 main() 函数开始。

package main

import (
    "bytes"
    "flag"
    "fmt"
    "io/ioutil"
    "log"
    "os"
    "path/filepath"
    "strings"
    "time"

    "golang.org/x/net/context"
    "golang.org/x/sync/errgroup"
)

func main() {
    duration := flag.Duration("timeout", 500*time.Millisecond, "timeout in milliseconds")
    flag.Usage = func() {
        fmt.Printf("%s by Brian Ketelsen\n", os.Args[0])
        fmt.Println("Usage:")
        fmt.Printf("    gogrep [flags] path pattern \n")
        fmt.Println("Flags:")
        flag.PrintDefaults()

    }
    flag.Parse()
    if flag.NArg() != 2 {
        flag.Usage()
        os.Exit(-1)
    }
    path := flag.Arg(0)
    pattern := flag.Arg(1)
    ctx, _ := context.WithTimeout(context.Background(), *duration)
    m, err := search(ctx, path, pattern)
    if err != nil {
        log.Fatal(err)
    }
    for _, name := range m {
        fmt.Println(name)
    }
    fmt.Println(len(m), "hits")
}

程序的前15行对命令行传进来的参数做了解析。第一段比较又去的代码是在第16行:

ctx, _ := context.WithTimeout(context Backgroud(), *duration)

这里,我给 context.Context 加了一个超时的时间。这个超时变量是通过 duration 来设置的。当超时时间到了,"ctx" 将接受到 channel 的超时警告。WithTimeout 同样也会返回一个取消的方法,但是我们不需要,所以用 "_" 来取消了。

下面 search() 方法的参数有 context, search path, 和 search pattern。最后把找到的文件和数量输出到终端上。

分解 search() 方法:

这个 search() 函数比 main() 函数要长,所以我们把它分解开解释。

首先 search() 函数创建了一个新的 errgroup。

func search(ctx context.Context, root string, pattern string) ([]string, error) {

g, ctx := errgroup.WithContext(ctx)

下面我创建了 channel 用来传递被搜索到的文件。稍后我们将发送搜索到的文件到 channel 中去判断这些文件是否符合 pattern 参数。这个 channel 开启的buffer 数为100

paths := make(chan string, 100)

这个 errgroup 类型有两个方法:Wait() 和 Go()。Go() 创建一个任务,Wait() 等待所有的任务完成。现在我们调用参数为匿名函数、返回值为error的 Go() 方法。

当所有的目录搜索完成时,我们将用 defer 来关闭 "paths" channel。后续我们将在更多的 goroutines 中使用这些文件。

defer close(paths)

最后我们使用 filepath 包提供的 Walk() 方法去递归查找指定目录的所有文件。我们将检查这些文件是否可读,是否带有 ".go" 后缀的文件。

return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {

        if err != nil {
            return err
        }
        if !info.Mode().IsRegular() {
            return nil
        }
        if !info.IsDir() && !strings.HasSuffix(info.Name(), ".go") {
            return nil
        }

发现 sync.Errgroup 真正的能力:

上面的那些过滤条件将丢掉不符合的文件。现在我们正式展示一下 sync.Errgroup 的真正能力。我在 select 里放了两个条件。首先发送 path 到 paths channel 里,另外的 goroutine 将会接收这个 channel 里的数据。第二个事件就是等待 context 超时发生。只要没到超时时间,就会继续处理文件。当超时后,context 的 Done channel 将发送数据导致 goroutine 返回,这个返回将停止文件搜索。

    select {
        case paths <- path:
        case <-ctx.Done():
            return ctx.Err()
        }
        return nil
    })

})

下面我将创建一个 channel 去处理文件是否符合 patter 参数。

c := make(chan string,100)

现在我将遍历所有的文件查找他们的内容:

for path := range paths {

这里我将说明的是我们将为每个文件创建一个 goroutine 去做模式匹配:

    p := path
    g.Go(func() error {
        data, err := ioutil.ReadFile(p)
        if err != nil {
            return err
        }
        if !bytes.Contains(data, []byte(pattern)) {
            return nil
        }

我们将再次使用 select 去等待处理完成的文件和超时。

        select {
        case c <- p:
        case <-ctx.Done():
            return ctx.Err()
        }
        return nil
    })
}

这个函数将会等待所有的 errgroup 的 goroutines 全部完成后关闭结果的 channel。

go func() {
    g.Wait()
    close(c)
}()

现在我们可以收集到所有的文件了。

var m []string
for r := range c {
    m = append(m, r)
}

最后我们将这些文件返回给 main() 函数作为结果:

return m, g.Wait()

}

这只是个简单的例子,完整的代码在 Github 上

`