基于Golang实现反向代理服务器(4)定制Transport优化连接池

回顾与目标:

在之前的代码中,我们使用 http.Transport 来实现 HTTP 的连接。

http.Transport 是 Go 中 HTTP 客户端的底层实现,负责:

  • 支持代理、TLS 等
  • 管理 TCP 连接(创建、复用、关闭)
  • 设置超时

默认的 http.DefaultClient 使用 http.DefaultTransport。其等效配置如下(Go 1.16+;注意 MaxIdleConnsPerHost 并未在 DefaultTransport 字面量中显式设置,而是取自常量 DefaultMaxIdleConnsPerHost = 2):

var DefaultTransport = &http.Transport{
    MaxIdleConns:        100,
    MaxIdleConnsPerHost: 2,   // 关键:每个主机最多保持 2 个空闲连接
    IdleConnTimeout:     90 * time.Second,
    TLSHandshakeTimeout: 10 * time.Second,
    ExpectContinueTimeout: 1 * time.Second,
    // ... 其他字段
}

这里最影响性能的是 MaxIdleConnsPerHost = 2。如果你的代理频繁请求同一批后端,每个后端最多只能复用 2 个连接,当并发超过 2 时,就必须创建新连接,用完关闭,导致大量 TIME_WAIT 和端口占用。

因此,我们希望对此做出优化

  • 提高连接复用率:增大 MaxIdleConnsPerHost,让每个后端可以保持更多空闲连接,减少新建连接的开销。
  • 控制总空闲连接数:设置 MaxIdleConns 避免占用过多内存。
  • 设置合理的超时:IdleConnTimeout 控制空闲连接最大存活时间,避免长时间占用;ResponseHeaderTimeout 等防止后端响应慢导致代理卡死。
  • 避免端口耗尽:通过复用连接,减少主动关闭的连接,从而减少 TIME_WAIT 状态的连接数(每个短暂连接都会在客户端产生一个 TIME_WAIT,耗尽本地端口)。

实现方式

我们将创建一个自定义的 http.Client,并配置它的 Transport。为了复用,可以在 main 函数中创建一次,然后传递给 ReverseProxy

代码如下:

    // 自定义 Transport
    transport := &http.Transport{
        MaxIdleConns:          100,              // 最大空闲连接总数
        MaxIdleConnsPerHost:   10,               // 每个后端最多保持 10 个空闲连接
        IdleConnTimeout:       90 * time.Second, // 空闲连接超时
        TLSHandshakeTimeout:   10 * time.Second,
        ResponseHeaderTimeout: 5 * time.Second, // 等待后端响应头的超时
        ExpectContinueTimeout: 1 * time.Second,
        // 还可以设置其他字段,如 DialContext 等
    }

    // 自定义 Client,使用这个 Transport
    client := &http.Client{
        Transport: transport,
        Timeout:   30 * time.Second, // 整个请求的超时(包括连接、读取等)
    }

然后修改 ReverseProxy 函数,增加一个 client *http.Client 参数,并使用它代替 http.DefaultClient

完整实例代码

代码如下:

package main

import (
	"hash/fnv"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"sync"
	"sync/atomic"
	"time"
)

// Backend represents one upstream server together with its health flag and
// in-flight request counter.
type Backend struct {
	URL         *url.URL // base URL of the upstream server
	Alive       int32    // 1=alive, 0=down; always read/written via sync/atomic
	Connections int64    // in-flight request count, updated atomically by the proxy handler
}

// IsAlive reports whether the backend is currently marked healthy.
// The flag is loaded atomically, so this is safe from any goroutine.
func (b *Backend) IsAlive() bool {
	state := atomic.LoadInt32(&b.Alive)
	return state == 1
}

// SetAlive atomically records the backend's health state
// (1 for alive, 0 for down).
func (b *Backend) SetAlive(alive bool) {
	var state int32
	if alive {
		state = 1
	}
	atomic.StoreInt32(&b.Alive, state)
}

// LoadBalancer distributes requests over a fixed set of backends using one of
// several selection strategies, and runs an optional background health check.
type LoadBalancer struct {
	backends []*Backend    // backend pool; populated once at construction
	current  uint64        // round-robin counter, advanced atomically
	mu       sync.RWMutex  // guards access to backends in the strategy methods
	stopChan chan struct{} // closed by Stop to end the health-check goroutine
	randSrc  *rand.Rand    // source for the random strategy
	randMu   sync.Mutex    // serializes randSrc (rand.Rand is not goroutine-safe)
}

// NewLoadBalancer builds a LoadBalancer over the given backend base URLs.
// All backends start out marked alive; the first health check corrects that.
// Invalid URLs are logged and skipped — the original code discarded the
// url.Parse error, leaving a Backend with a nil URL that would panic later.
func NewLoadBalancer(urls []string) *LoadBalancer {
	backends := make([]*Backend, 0, len(urls))
	for _, u := range urls {
		parsed, err := url.Parse(u)
		if err != nil {
			log.Printf("skipping invalid backend URL %q: %v", u, err)
			continue
		}
		backends = append(backends, &Backend{
			URL:         parsed,
			Alive:       1, // assume alive until proven otherwise
			Connections: 0,
		})
	}
	return &LoadBalancer{
		backends: backends,
		current:  0,
		stopChan: make(chan struct{}),
		randSrc:  rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// NextRoundRobin selects backends in rotation, skipping unhealthy ones.
// It returns nil when no healthy backend exists.
func (lb *LoadBalancer) NextRoundRobin() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	count := len(lb.backends)
	if count == 0 {
		return nil
	}

	// Advance the shared counter exactly once, then probe forward from the
	// resulting slot until a live backend turns up (at most one full cycle).
	first := int(atomic.AddUint64(&lb.current, 1) % uint64(count))
	for offset := 0; offset < count; offset++ {
		candidate := lb.backends[(first+offset)%count]
		if candidate.IsAlive() {
			return candidate
		}
	}
	return nil
}

// NextRandom returns a randomly chosen healthy backend. It gives up and
// returns nil after n random probes (n = number of backends) all hit dead ones.
func (lb *LoadBalancer) NextRandom() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	count := len(lb.backends)
	if count == 0 {
		return nil
	}

	for attempt := 0; attempt < count; attempt++ {
		// rand.Rand is not goroutine-safe, so serialize access to it.
		lb.randMu.Lock()
		pick := lb.randSrc.Intn(count)
		lb.randMu.Unlock()

		if candidate := lb.backends[pick]; candidate.IsAlive() {
			return candidate
		}
	}
	return nil
}

// NextLeastConnections returns the healthy backend with the fewest in-flight
// requests (first one wins on ties), or nil when every backend is down.
func (lb *LoadBalancer) NextLeastConnections() *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	var best *Backend
	var bestConns int64

	for _, candidate := range lb.backends {
		if !candidate.IsAlive() {
			continue
		}
		active := atomic.LoadInt64(&candidate.Connections)
		if best == nil || active < bestConns {
			best = candidate
			bestConns = active
		}
	}
	return best
}

// NextIPHash maps a client IP to a backend via FNV-1a hashing so the same
// client consistently reaches the same backend. If that backend is down it
// linearly probes for the next healthy one; returns nil when all are down.
func (lb *LoadBalancer) NextIPHash(ip string) *Backend {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	n := len(lb.backends)
	if n == 0 {
		return nil
	}

	h := fnv.New32a()
	h.Write([]byte(ip))
	// Reduce modulo n while still unsigned: the original int(hash) % n could
	// yield a negative index (and panic) on platforms where int is 32 bits.
	// On 64-bit platforms the selected backend is unchanged.
	idx := int(h.Sum32() % uint32(n))

	// Linear probe so a dead "home" backend falls through to a live one.
	for i := 0; i < n; i++ {
		candidate := lb.backends[(idx+i)%n]
		if candidate.IsAlive() {
			return candidate
		}
	}
	return nil
}

// StartHealthCheck launches a background goroutine that probes every backend
// once per interval. The goroutine exits when Stop closes stopChan.
func (lb *LoadBalancer) StartHealthCheck(interval time.Duration) {
	go func() {
		probe := time.NewTicker(interval)
		defer probe.Stop()
		for {
			select {
			case <-lb.stopChan:
				return
			case <-probe.C:
				lb.healthCheck()
			}
		}
	}()
}

// Stop terminates the health-check goroutine started by StartHealthCheck.
// NOTE(review): calling Stop more than once panics (double close of
// stopChan) — confirm callers invoke it at most once.
func (lb *LoadBalancer) Stop() {
	close(lb.stopChan)
}

// healthCheck probes every backend concurrently, updates each Alive flag,
// and blocks until all probes have finished.
func (lb *LoadBalancer) healthCheck() {
	lb.mu.RLock()
	targets := lb.backends
	lb.mu.RUnlock()

	var wg sync.WaitGroup
	wg.Add(len(targets))
	for _, target := range targets {
		// Pass the backend as an argument so each goroutine gets its own copy
		// (loop variables were shared before Go 1.22).
		go func(b *Backend) {
			defer wg.Done()
			up := checkBackendHealth(b.URL)
			b.SetAlive(up)
			if !up {
				log.Printf("Backend %s is DOWN", b.URL.String())
			}
			// Recovery is intentionally not logged to keep output quiet.
		}(target)
	}
	wg.Wait()
}

// healthClient is shared by all health probes so its (default) transport can
// reuse pooled connections instead of the per-call client the original built,
// which defeated connection reuse bookkeeping and allocated on every probe.
var healthClient = &http.Client{
	Timeout: 2 * time.Second, // a probe slower than this counts as down
}

// checkBackendHealth reports whether a HEAD request to the backend's /health
// endpoint succeeds with a 2xx status within the probe timeout.
func checkBackendHealth(u *url.URL) bool {
	// 假设健康检查路径为 /health
	healthURL := u.ResolveReference(&url.URL{Path: "/health"})
	resp, err := healthClient.Head(healthURL.String())
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode >= 200 && resp.StatusCode < 300
}

// hopByHopHeaders are connection-scoped headers that must not be forwarded
// between client and backend (RFC 7230 §6.1).
var hopByHopHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// ReverseProxy returns a handler that forwards each request to a backend
// chosen by the given strategy, using the supplied shared client (and thus
// its tuned Transport connection pool).
func ReverseProxy(lb *LoadBalancer, strategy string, client *http.Client) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var backend *Backend
		switch strategy {
		case "round-robin":
			backend = lb.NextRoundRobin()
		case "random":
			backend = lb.NextRandom()
		case "least-conn":
			backend = lb.NextLeastConnections()
		case "ip-hash":
			ip, _, _ := net.SplitHostPort(r.RemoteAddr)
			backend = lb.NextIPHash(ip)
		default:
			http.Error(w, "unknown strategy", http.StatusInternalServerError)
			return
		}

		if backend == nil {
			http.Error(w, "no available backend", http.StatusServiceUnavailable)
			return
		}

		// Track in-flight requests for the least-connections strategy.
		atomic.AddInt64(&backend.Connections, 1)
		defer atomic.AddInt64(&backend.Connections, -1)

		proxyURL := backend.URL.ResolveReference(r.URL)
		// Propagate the client's context so the backend request is canceled
		// when the client disconnects (http.NewRequest kept it running).
		req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL.String(), r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		req.Header = r.Header.Clone()
		// Hop-by-hop headers are meaningful only on the incoming connection;
		// forwarding them can confuse keep-alive handling on the backend leg.
		for _, h := range hopByHopHeaders {
			req.Header.Del(h)
		}
		// Record the client's address for the backend, appending to any chain
		// built by upstream proxies.
		if clientIP, _, splitErr := net.SplitHostPort(r.RemoteAddr); splitErr == nil {
			if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
				clientIP = prior + ", " + clientIP
			}
			req.Header.Set("X-Forwarded-For", clientIP)
		}
		req.Host = backend.URL.Host

		// 使用自定义 client
		resp, err := client.Do(req)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		// Copy backend response headers, again dropping hop-by-hop ones.
		for key, values := range resp.Header {
			for _, value := range values {
				w.Header().Add(key, value)
			}
		}
		for _, h := range hopByHopHeaders {
			w.Header().Del(h)
		}
		w.WriteHeader(resp.StatusCode)
		// Best-effort body copy; the client may disconnect mid-stream.
		io.Copy(w, resp.Body)
	}
}

// main wires together the load balancer, health checking, and the shared
// HTTP client with a tuned Transport, then serves the proxy on :8080.
func main() {
	// Upstream servers to balance across.
	pool := []string{
		"http://127.0.0.1:8001",
		"http://127.0.0.1:8002",
		"http://127.0.0.1:8003",
	}
	balancer := NewLoadBalancer(pool)

	// Probe backend health every 5 seconds.
	balancer.StartHealthCheck(5 * time.Second)

	// Selection strategy; swap for "random", "least-conn", or "ip-hash".
	strategy := "round-robin"

	// Shared transport: a larger per-host idle pool than the default (2)
	// lets concurrent requests reuse connections instead of churning them.
	sharedTransport := &http.Transport{
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		ResponseHeaderTimeout: 5 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	sharedClient := &http.Client{
		Transport: sharedTransport,
		Timeout:   30 * time.Second, // whole-request deadline: connect + headers + body
	}

	http.HandleFunc("/", ReverseProxy(balancer, strategy, sharedClient))

	log.Println("Starting proxy server on :8080")
	log.Fatal(http.ListenAndServe(":8080", nil))
}

发表评论