timeout client connections

This commit is contained in:
chrislu
2025-09-03 15:49:27 -07:00
parent 323193cf8c
commit 191bad0a21

View File

@@ -91,14 +91,15 @@ const (
// PostgreSQL server configuration
type PostgreSQLServerConfig struct {
Host string
Port int
AuthMethod AuthMethod
Users map[string]string
TLSConfig *tls.Config
MaxConns int
IdleTimeout time.Duration
Database string
Host string
Port int
AuthMethod AuthMethod
Users map[string]string
TLSConfig *tls.Config
MaxConns int
IdleTimeout time.Duration
StartupTimeout time.Duration // Timeout for client startup handshake
Database string
}
// PostgreSQL server
@@ -177,6 +178,9 @@ func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*Po
if config.IdleTimeout <= 0 {
config.IdleTimeout = time.Hour
}
if config.StartupTimeout <= 0 {
config.StartupTimeout = 30 * time.Second
}
// Create SQL engine with PostgreSQL parser for proper dialect compatibility
// Use PostgreSQL parser since we're implementing PostgreSQL wire protocol
@@ -325,12 +329,19 @@ func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
s.sessionMux.Unlock()
}()
glog.Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
// Handle startup
err := s.handleStartup(session)
if err != nil {
glog.Errorf("Startup failed for connection %d: %v", connID, err)
// Handle common disconnection scenarios more gracefully
if strings.Contains(err.Error(), "client disconnected") {
glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
} else if strings.Contains(err.Error(), "timeout") {
glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
} else {
glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
}
return
}
@@ -361,19 +372,42 @@ func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
// handleStartup processes the PostgreSQL startup sequence
func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
// Set a startup timeout to prevent hanging connections
startupTimeout := s.config.StartupTimeout
session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
for {
// Read startup message
// Read startup message length
length := make([]byte, 4)
_, err := io.ReadFull(session.reader, length)
if err != nil {
return err
if err == io.EOF {
// Client disconnected during startup - this is common for health checks
return fmt.Errorf("client disconnected during startup handshake")
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
}
return fmt.Errorf("failed to read message length during startup: %v", err)
}
msgLength := binary.BigEndian.Uint32(length) - 4
if msgLength > 10000 { // Reasonable limit for startup messages
return fmt.Errorf("startup message too large: %d bytes", msgLength)
}
// Read startup message content
msg := make([]byte, msgLength)
_, err = io.ReadFull(session.reader, msg)
if err != nil {
return err
if err == io.EOF {
return fmt.Errorf("client disconnected while reading startup message")
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return fmt.Errorf("startup message read timeout")
}
return fmt.Errorf("failed to read startup message: %v", err)
}
// Parse protocol version