1 pool.go
package pool
import (
"errors"
"log"
"io"
"sync"
)
// Pool keeps a bounded set of idle resources that goroutines can borrow and
// hand back. Anything that implements io.Closer can be pooled.
//
// The pool bounds only the number of idle resources it retains; Acquire
// creates a fresh resource whenever none is idle.
type Pool struct {
	m         sync.Mutex                // serializes Release and Close and guards closed
	resources chan io.Closer            // buffered channel holding the idle resources
	factory   func() (io.Closer, error) // builds a new resource when none is idle
	closed    bool                      // set once Close has run
}

// ErrPoolClosed is returned by Acquire after the pool has been closed.
var ErrPoolClosed = errors.New("Pool has been closed.")

// New returns a pool that retains at most size idle resources and uses fn to
// create new ones on demand.
//
// It returns an error when size is zero or fn is nil; a nil factory would
// otherwise only surface later as a panic inside Acquire.
func New(fn func() (io.Closer, error), size uint) (*Pool, error) {
	if size == 0 {
		return nil, errors.New("Size value too small.")
	}
	if fn == nil {
		return nil, errors.New("Factory function is nil.")
	}
	return &Pool{
		factory:   fn,
		resources: make(chan io.Closer, size), // buffered: capacity is the idle limit
	}, nil
}

// Acquire hands out an idle resource when one is available and otherwise
// creates a new one with the factory. It never blocks waiting for a resource.
//
// After Close it returns ErrPoolClosed. Errors from the factory are returned
// unchanged.
func (p *Pool) Acquire() (io.Closer, error) {
	select {
	case r, ok := <-p.resources:
		// A receive on the closed channel yields ok == false, which is how
		// a closed pool is detected here without taking the mutex.
		if !ok {
			return nil, ErrPoolClosed
		}
		log.Println("Acquire:", "Shared Resource")
		return r, nil
	default:
		// Nothing idle: build a new resource instead of blocking.
		log.Println("Acquire:", "New Resource")
		return p.factory()
	}
}

// Release gives r back to the pool. The resource is closed instead of being
// kept when the pool has already been closed or already holds its full
// complement of idle resources. Release never blocks on a full pool.
func (p *Pool) Release(r io.Closer) {
	// Holding the mutex keeps Close from closing the channel between the
	// closed check and the send below.
	p.m.Lock()
	defer p.m.Unlock()

	if p.closed {
		// The pool is gone, so the resource has nowhere to go.
		r.Close()
		return
	}

	select {
	case p.resources <- r:
		// There was room in the buffer; keep the resource for reuse.
		log.Println("Release:", "In Queue")
	default:
		// The buffer is full, so there are already enough idle resources.
		log.Println("Release:", "Closing")
		r.Close()
	}
}

// Close shuts the pool down and closes every idle resource it still holds.
// Calling Close more than once is a no-op.
func (p *Pool) Close() {
	p.m.Lock()
	defer p.m.Unlock()

	if p.closed {
		return
	}
	p.closed = true

	// Close the channel before draining it; otherwise the range loop below
	// would block forever once the buffer is empty.
	close(p.resources)
	for r := range p.resources {
		r.Close()
	}
}
2 main.go
package main
import (
"log"
"io"
"math/rand"
"sync"
"sync/atomic"
"time"
"ch7/pool"
)
// Sizing of the demo workload.
const (
	// maxGoroutines is how many worker goroutines main starts.
	maxGoroutines = 25

	// pooledResources is the size handed to pool.New.
	pooledResources = 2
)
// dbConnection stands in for a real database connection; it carries only an
// identifier so that log output shows which connection was used.
type dbConnection struct {
	ID int32
}

// Close implements io.Closer. It logs the connection ID and always returns nil.
func (c *dbConnection) Close() error {
	log.Println("Close: Connection", c.ID)
	return nil
}
// idCounter is the source of connection IDs; it is only modified through
// atomic.AddInt32.
var idCounter int32

// createConnection is the factory function passed to the pool. Each call
// returns a new dbConnection with the next ID and a nil error.
func createConnection() (io.Closer, error) {
	conn := &dbConnection{ID: atomic.AddInt32(&idCounter, 1)}
	log.Println("Create: New Connection", conn.ID)
	return conn, nil
}
func main() {
var wg sync.WaitGroup
wg.Add(maxGoroutines)
p, err := pool.New(createConnection, pooledResources)
if err != nil {
log.Println(err)
}
for query := 0; query < maxGoroutines; query++ {
go func(q int){
performQueries(q,p)
wg.Done()
}(query)
}
wg.Wait()
log.Println("Shutdown Program.")
p.Close()
}
// performQueries simulates one query: it borrows a connection from p, sleeps
// for a random duration below one second, logs the query and connection IDs,
// and hands the connection back when it returns. If no connection can be
// acquired, the error is logged and the function returns.
func performQueries(query int, p *pool.Pool) {
	resource, err := p.Acquire()
	if err != nil {
		log.Println(err)
		return
	}
	defer p.Release(resource)

	// Stand-in for the time a real query would take.
	delay := time.Duration(rand.Intn(1000)) * time.Millisecond
	time.Sleep(delay)

	db := resource.(*dbConnection)
	log.Printf("QID[%d] CID[%d]\n", query, db.ID)
}