go线程池

最近工作用到go 多线程相关, 想实现一个qps可控的线程池,仿照java的线程池写了一个,被同事评为java味过浓…

新手,会有很多错误,请大家多指教!

代码

package threads

import (
	"fmt"
	"go.uber.org/atomic"
	"time"
)

// Task 任务
type Task interface {
	// Run 执行认为
	Run()
}

// Pool 线程池
type Pool struct {
	// 任务队列
	taskChannel chan Task
	// 最大worker数量
	capacity int
	// go runtime数组
	workers []int
	// status
	starting *atomic.Bool
	// 执行任务数量
	count *atomic.Int64
}

func NewPool(capacity int) *Pool {
	pool := &Pool{
		taskChannel: make(chan Task),
		capacity:    capacity,
		starting:    atomic.NewBool(false),
		count:       atomic.NewInt64(0),
	}
	(*pool).initWorker()
	return pool
}

func (p *Pool) initWorker() {
	if !p.starting.CAS(false, true) {
		return
	}
	//fmt.Println("启动worker!")
	for i := 0; i < p.capacity; i++ {
		go p.work(i)
		//fmt.Printf("worker %v ready!\n", i)
	}
}

func (p *Pool) work(workerId int) {
	for {
		task, running := <-p.taskChannel
		if !running {
			//fmt.Printf("worker %v has closed! \n", workerId)
			break
		}
		if task == nil {
			fmt.Println("task is nil, continue!")
			continue
		}
		task.Run()
		p.count.Sub(1)
	}
}

func (p *Pool) AddTask(task Task) bool {
	if !p.starting.Load() {
		fmt.Println("线程池已关闭!")
		return false
	}
	p.taskChannel <- task
	p.count.Add(1)
	//fmt.Println("添加一个任务!")
	return true
}

// Close 关闭线程池 until: 超时时间,最小10ms
func (p *Pool) Close(until time.Duration) {
	//fmt.Println("准备关闭!")
	if !p.starting.CAS(true, false) {
		//fmt.Println("不可重复关闭!")
		return
	}
	//fmt.Println("不再接收新的任务!")
	if until.Milliseconds() > 0 {
		begin := time.Now()
		end := begin.Add(until)
		for now := begin; now.Before(end); now = time.Now() {
			if p.count.Load() <= 0 {
				//fmt.Println("任务执行完成!")
				break
			}
			time.Sleep(time.Millisecond * 10)
		}
		if p.count.Load() > 0 {
			fmt.Printf("即将结束线程组, 还有 %v 个任务未执行!\n", p.count.Load())
		} /*else {
			fmt.Println("即将结束线程池!")
		}*/
	} /*else {
		fmt.Println("不等待, 立即结束线程组!")
	}*/
	close(p.taskChannel)
	//fmt.Println("线程池结束!")
}

测试

package threads

import (
	"fmt"
	"testing"
	"time"
)

type TestTask struct {
	Name string
}

func (task *TestTask) Run() {
	fmt.Printf("任务 %v 执行中! \n", task.Name)
}

func TestPool(t *testing.T) {
	pool := NewPool(10)
	for i := 0; i < 10000; i++ {
		name := fmt.Sprintf("%v", i)
		task := TestTask{Name: name}
		pool.AddTask(&task)
	}
	pool.Close(time.Microsecond * 6000)
}