基于Golang实现反向代理服务器(3)健康检查与动态后端管理

回顾与目标:

在之前的实现中, Backend 结构体中已经有了一个 Alive 字段,但是我们从来没有更新过它。所有后端默认都是 Alive = true 。如果某个后端服务挂了,代理仍然会把请求转发给它,导致502错误。

因此,在本章中,我们期望实现

  1. 启动一个后台 goroutine ,定期(比如每5秒)向所有后端发送健康检查请求。
  2. 如果检查成功,将对应后端的 Alive 设为 true ;如果失败,则设为 false 。
  3. 修改负载均衡算法,只选择 Alive = true 的后端。(当然在实际实现中我们将其改为 Int32 ,这样比较好实现)
  4. 确保健康检查和负载均衡之间的数据访问是并发安全的。

一、关键技术点

1. 定时器 time.Ticker

time.NewTicker(duration) 会返回一个通道,每隔指定时间就会往通道发送一个时间值。

当通道收到信号时,进行健康检查。

2. 并发安全地更新 Alive

Alive 字段会被健康检查地 goroutine 写入,同时会被多个处理请求的 goroutine 读取。因此在这里需要使用原子操作

二、实现健康检查

1. 添加健康检查辅助方法

代码如下:

func (b *Backend) IsAlive() bool {
	return atomic.LoadInt32(&b.Alive) == 1
}

func (b *Backend) SetAlive(alive bool) {
	if alive {
		atomic.StoreInt32(&b.Alive, 1)
	} else {
		atomic.StoreInt32(&b.Alive, 0)
	}
}

2. 修改负载均衡算法,只返回存活的节点

代码如下:

// 轮询
func (lb *LoadBalancer) NextRoundRobin() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	start := atomic.AddUint64(&lb.current, 1) % uint64(n)
	for i := 0; i < n; i++ {
		idx := (int(start) + i) % n
		backend := lb.backends[idx]
		if backend.IsAlive() {
			return backend
		}
	}
	return nil
}

// 随机
func (lb *LoadBalancer) NextRandom() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	// 最多尝试 n 次
	for i := 0; i < n; i++ {
		lb.randMu.Lock()
		idx := lb.randSrc.Intn(n)
		lb.randMu.Unlock()
		backend := lb.backends[idx]
		if backend.IsAlive() {
			return backend
		}
	}
	return nil
}

// 最少连接数
func (lb *LoadBalancer) NextLeastConnections() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	var selected *Backend
	var minConns int64 = 1<<63 - 1

	for _, b := range lb.backends {
		if !b.IsAlive() {
			continue
		}
		conns := atomic.LoadInt64(&b.Connections)
		if conns < minConns {
			minConns = conns
			selected = b
		}
	}
	return selected
}

// IP 哈希
func (lb *LoadBalancer) NextIPHash(ip string) *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	h := fnv.New32a()
	h.Write([]byte(ip))
	hash := h.Sum32()
	idx := int(hash) % n

	// 如果选中的后端不存活,可以尝试线性探测下一个存活的后端
	for i := 0; i < n; i++ {
		candidate := lb.backends[(idx+i)%n]
		if candidate.IsAlive() {
			return candidate
		}
	}
	return nil
}

3. 添加健康检查方法

在 LoadBalancer 上添加一个 StartHealthCheck 方法,启动后台 goroutine 。

同时,为了在程序退出时能够正确地停止后台 goroutine , 我们可以给 LoadBalancer 增加一个停止通道。

注意:健康检查的路径 /health 只是一个例子,实际后端中可能提供不同的健康检查接口,也可以设计成允许配置。

func (lb *LoadBalancer) StartHealthCheck(interval time.Duration) {
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				lb.healthCheck()
			case <-lb.stopChan:
				return
			}
		}
	}()
}

func (lb *LoadBalancer) Stop() {
	close(lb.stopChan)
}

func (lb *LoadBalancer) healthCheck() {
	lb.mu.RLock()
	backends := lb.backends
	lb.mu.RUnlock()

	var wg sync.WaitGroup
	for _, b := range backends {
		wg.Add(1)
		go func(backend *Backend) {
			defer wg.Done()
			alive := checkBackendHealth(backend.URL)
			backend.SetAlive(alive)
			if !alive {
				log.Printf("Backend %s is DOWN", backend.URL.String())
			} else {
				// 可以记录恢复日志,但避免太多
				// log.Printf("Backend %s is UP", backend.URL.String())
			}
		}(b)
	}
	wg.Wait()
}

func checkBackendHealth(u *url.URL) bool {
	client := http.Client{
		Timeout: 2 * time.Second,
	}
	// 假设健康检查路径为 /health
	healthURL := u.ResolveReference(&url.URL{Path: "/health"})
	resp, err := client.Head(healthURL.String())
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode >= 200 && resp.StatusCode < 300
}

三、重点代码解析

在健康检查部分中,有如下代码:

for {
    select {
    case <-ticker.C:
        lb.healthCheck()   // 定时器触发时,执行健康检查
    case <-lb.stopChan:
        return             // 收到停止信号时,退出循环
    }
}
  • case <-ticker.C:等待 ticker 的定时信号,然后做健康检查。
  • case <-lb.stopChan:等待停止信号,一旦收到(通道被关闭),立即退出循环,goroutine 结束。

select 会阻塞直到其中一个 case 可以执行。如果多个 case 同时满足,会随机选择一个。

四、完整示例代码

代码如下:

package main

import (
	"hash/fnv"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"sync"
	"sync/atomic"
	"time"
)

type Backend struct {
	URL         *url.URL
	Alive       int32 // 1=存活,0=不存活
	Connections int64
}

func (b *Backend) IsAlive() bool {
	return atomic.LoadInt32(&b.Alive) == 1
}

func (b *Backend) SetAlive(alive bool) {
	if alive {
		atomic.StoreInt32(&b.Alive, 1)
	} else {
		atomic.StoreInt32(&b.Alive, 0)
	}
}

type LoadBalancer struct {
	backends []*Backend
	current  uint64
	mu       sync.RWMutex
	stopChan chan struct{}
	randSrc  *rand.Rand
	randMu   sync.Mutex
}

func NewLoadBalancer(urls []string) *LoadBalancer {
	var backends []*Backend
	for _, u := range urls {
		parsed, _ := url.Parse(u)
		backends = append(backends, &Backend{
			URL:         parsed,
			Alive:       1, // 默认存活
			Connections: 0,
		})
	}
	return &LoadBalancer{
		backends: backends,
		current:  0,
		stopChan: make(chan struct{}),
		randSrc:  rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// 轮询
func (lb *LoadBalancer) NextRoundRobin() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	start := atomic.AddUint64(&lb.current, 1) % uint64(n)
	for i := 0; i < n; i++ {
		idx := (int(start) + i) % n
		backend := lb.backends[idx]
		if backend.IsAlive() {
			return backend
		}
	}
	return nil
}

// 随机
func (lb *LoadBalancer) NextRandom() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	// 最多尝试 n 次
	for i := 0; i < n; i++ {
		lb.randMu.Lock()
		idx := lb.randSrc.Intn(n)
		lb.randMu.Unlock()
		backend := lb.backends[idx]
		if backend.IsAlive() {
			return backend
		}
	}
	return nil
}

// 最少连接数
func (lb *LoadBalancer) NextLeastConnections() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	var selected *Backend
	var minConns int64 = 1<<63 - 1

	for _, b := range lb.backends {
		if !b.IsAlive() {
			continue
		}
		conns := atomic.LoadInt64(&b.Connections)
		if conns < minConns {
			minConns = conns
			selected = b
		}
	}
	return selected
}

// IP 哈希
func (lb *LoadBalancer) NextIPHash(ip string) *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	h := fnv.New32a()
	h.Write([]byte(ip))
	hash := h.Sum32()
	idx := int(hash) % n

	// 如果选中的后端不存活,可以尝试线性探测下一个存活的后端
	for i := 0; i < n; i++ {
		candidate := lb.backends[(idx+i)%n]
		if candidate.IsAlive() {
			return candidate
		}
	}
	return nil
}

// 健康检查
func (lb *LoadBalancer) StartHealthCheck(interval time.Duration) {
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				lb.healthCheck()
			case <-lb.stopChan:
				return
			}
		}
	}()
}

func (lb *LoadBalancer) Stop() {
	close(lb.stopChan)
}

func (lb *LoadBalancer) healthCheck() {
	lb.mu.RLock()
	backends := lb.backends
	lb.mu.RUnlock()

	var wg sync.WaitGroup
	for _, b := range backends {
		wg.Add(1)
		go func(backend *Backend) {
			defer wg.Done()
			alive := checkBackendHealth(backend.URL)
			backend.SetAlive(alive)
			if !alive {
				log.Printf("Backend %s is DOWN", backend.URL.String())
			} else {
				// 可以记录恢复日志,但避免太多
				// log.Printf("Backend %s is UP", backend.URL.String())
			}
		}(b)
	}
	wg.Wait()
}

func checkBackendHealth(u *url.URL) bool {
	client := http.Client{
		Timeout: 2 * time.Second,
	}
	// 假设健康检查路径为 /health
	healthURL := u.ResolveReference(&url.URL{Path: "/health"})
	resp, err := client.Head(healthURL.String())
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode >= 200 && resp.StatusCode < 300
}

// 反向代理处理器
func ReverseProxy(lb *LoadBalancer, strategy string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var backend *Backend
		switch strategy {
		case "round-robin":
			backend = lb.NextRoundRobin()
		case "random":
			backend = lb.NextRandom()
		case "least-conn":
			backend = lb.NextLeastConnections()
		case "ip-hash":
			ip, _, _ := net.SplitHostPort(r.RemoteAddr)
			backend = lb.NextIPHash(ip)
		default:
			http.Error(w, "unknown strategy", http.StatusInternalServerError)
			return
		}

		if backend == nil {
			http.Error(w, "no available backend", http.StatusServiceUnavailable)
			return
		}

		atomic.AddInt64(&backend.Connections, 1)
		defer atomic.AddInt64(&backend.Connections, -1)

		proxyURL := backend.URL.ResolveReference(r.URL)
		req, err := http.NewRequest(r.Method, proxyURL.String(), r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		req.Header = r.Header.Clone()
		req.Host = backend.URL.Host

		client := http.DefaultClient
		resp, err := client.Do(req)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		for key, values := range resp.Header {
			for _, value := range values {
				w.Header().Add(key, value)
			}
		}
		w.WriteHeader(resp.StatusCode)
		io.Copy(w, resp.Body)
	}
}

func main() {
	backendURLs := []string{
		"http://127.0.0.1:8001",
		"http://127.0.0.1:8002",
		"http://127.0.0.1:8003",
	}
	lb := NewLoadBalancer(backendURLs)

	// 启动健康检查,每 5 秒一次
	lb.StartHealthCheck(5 * time.Second)

	strategy := "round-robin" // 可以改成其他策略测试

	http.HandleFunc("/", ReverseProxy(lb, strategy))

	log.Println("Starting proxy server on :8080")
	log.Fatal(http.ListenAndServe(":8080", nil))
}

发表评论