mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2025-09-20 15:09:23 +08:00
add postgres protocol
This commit is contained in:
640
weed/server/postgres/server.go
Normal file
640
weed/server/postgres/server.go
Normal file
@@ -0,0 +1,640 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/query/engine"
|
||||
)
|
||||
|
||||
// PostgreSQL protocol constants
|
||||
const (
|
||||
// Message types from client
|
||||
PG_MSG_STARTUP = 0x00
|
||||
PG_MSG_QUERY = 'Q'
|
||||
PG_MSG_PARSE = 'P'
|
||||
PG_MSG_BIND = 'B'
|
||||
PG_MSG_EXECUTE = 'E'
|
||||
PG_MSG_DESCRIBE = 'D'
|
||||
PG_MSG_CLOSE = 'C'
|
||||
PG_MSG_FLUSH = 'H'
|
||||
PG_MSG_SYNC = 'S'
|
||||
PG_MSG_TERMINATE = 'X'
|
||||
PG_MSG_PASSWORD = 'p'
|
||||
|
||||
// Response types to client
|
||||
PG_RESP_AUTH_OK = 'R'
|
||||
PG_RESP_BACKEND_KEY = 'K'
|
||||
PG_RESP_PARAMETER = 'S'
|
||||
PG_RESP_READY = 'Z'
|
||||
PG_RESP_COMMAND = 'C'
|
||||
PG_RESP_DATA_ROW = 'D'
|
||||
PG_RESP_ROW_DESC = 'T'
|
||||
PG_RESP_PARSE_COMPLETE = '1'
|
||||
PG_RESP_BIND_COMPLETE = '2'
|
||||
PG_RESP_CLOSE_COMPLETE = '3'
|
||||
PG_RESP_ERROR = 'E'
|
||||
PG_RESP_NOTICE = 'N'
|
||||
|
||||
// Transaction states
|
||||
PG_TRANS_IDLE = 'I'
|
||||
PG_TRANS_INTRANS = 'T'
|
||||
PG_TRANS_ERROR = 'E'
|
||||
|
||||
// Authentication methods
|
||||
AUTH_OK = 0
|
||||
AUTH_CLEAR = 3
|
||||
AUTH_MD5 = 5
|
||||
AUTH_TRUST = 10
|
||||
|
||||
// PostgreSQL data types
|
||||
PG_TYPE_BOOL = 16
|
||||
PG_TYPE_INT8 = 20
|
||||
PG_TYPE_INT4 = 23
|
||||
PG_TYPE_TEXT = 25
|
||||
PG_TYPE_FLOAT8 = 701
|
||||
PG_TYPE_VARCHAR = 1043
|
||||
PG_TYPE_TIMESTAMP = 1114
|
||||
PG_TYPE_JSON = 114
|
||||
|
||||
// Default values
|
||||
DEFAULT_POSTGRES_PORT = 5432
|
||||
)
|
||||
|
||||
// Authentication method type
|
||||
type AuthMethod int
|
||||
|
||||
const (
|
||||
AuthTrust AuthMethod = iota
|
||||
AuthPassword
|
||||
AuthMD5
|
||||
)
|
||||
|
||||
// PostgreSQL server configuration
|
||||
type PostgreSQLServerConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
AuthMethod AuthMethod
|
||||
Users map[string]string
|
||||
TLSConfig *tls.Config
|
||||
MaxConns int
|
||||
IdleTimeout time.Duration
|
||||
Database string
|
||||
}
|
||||
|
||||
// PostgreSQL server
|
||||
type PostgreSQLServer struct {
|
||||
config *PostgreSQLServerConfig
|
||||
listener net.Listener
|
||||
sqlEngine *engine.SQLEngine
|
||||
sessions map[uint32]*PostgreSQLSession
|
||||
sessionMux sync.RWMutex
|
||||
shutdown chan struct{}
|
||||
wg sync.WaitGroup
|
||||
translator *PostgreSQLTranslator
|
||||
nextConnID uint32
|
||||
}
|
||||
|
||||
// PostgreSQL session
|
||||
type PostgreSQLSession struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
authenticated bool
|
||||
username string
|
||||
database string
|
||||
parameters map[string]string
|
||||
preparedStmts map[string]*PreparedStatement
|
||||
portals map[string]*Portal
|
||||
transactionState byte
|
||||
processID uint32
|
||||
secretKey uint32
|
||||
created time.Time
|
||||
lastActivity time.Time
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Prepared statement
|
||||
type PreparedStatement struct {
|
||||
Name string
|
||||
Query string
|
||||
ParamTypes []uint32
|
||||
Fields []FieldDescription
|
||||
}
|
||||
|
||||
// Portal (cursor)
|
||||
type Portal struct {
|
||||
Name string
|
||||
Statement string
|
||||
Parameters [][]byte
|
||||
Suspended bool
|
||||
}
|
||||
|
||||
// Field description
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
AttrNum int16
|
||||
TypeOID uint32
|
||||
TypeSize int16
|
||||
TypeMod int32
|
||||
Format int16
|
||||
}
|
||||
|
||||
// NewPostgreSQLServer creates a new PostgreSQL protocol server
|
||||
func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
|
||||
if config.Port <= 0 {
|
||||
config.Port = DEFAULT_POSTGRES_PORT
|
||||
}
|
||||
if config.Host == "" {
|
||||
config.Host = "localhost"
|
||||
}
|
||||
if config.Database == "" {
|
||||
config.Database = "default"
|
||||
}
|
||||
if config.MaxConns <= 0 {
|
||||
config.MaxConns = 100
|
||||
}
|
||||
if config.IdleTimeout <= 0 {
|
||||
config.IdleTimeout = time.Hour
|
||||
}
|
||||
|
||||
// Create SQL engine
|
||||
sqlEngine := engine.NewSQLEngine(masterAddr)
|
||||
|
||||
// Initialize translator
|
||||
translator := &PostgreSQLTranslator{
|
||||
systemQueries: make(map[string]string),
|
||||
}
|
||||
translator.initSystemQueries()
|
||||
|
||||
server := &PostgreSQLServer{
|
||||
config: config,
|
||||
sqlEngine: sqlEngine,
|
||||
sessions: make(map[uint32]*PostgreSQLSession),
|
||||
shutdown: make(chan struct{}),
|
||||
translator: translator,
|
||||
nextConnID: 1,
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Start begins listening for PostgreSQL connections
|
||||
func (s *PostgreSQLServer) Start() error {
|
||||
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
|
||||
|
||||
var listener net.Listener
|
||||
var err error
|
||||
|
||||
if s.config.TLSConfig != nil {
|
||||
listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
|
||||
glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
glog.Infof("PostgreSQL Server listening on %s", addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
|
||||
// Start accepting connections
|
||||
s.wg.Add(1)
|
||||
go s.acceptConnections()
|
||||
|
||||
// Start cleanup routine
|
||||
s.wg.Add(1)
|
||||
go s.cleanupSessions()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the PostgreSQL server
|
||||
func (s *PostgreSQLServer) Stop() error {
|
||||
close(s.shutdown)
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
// Close all sessions
|
||||
s.sessionMux.Lock()
|
||||
for _, session := range s.sessions {
|
||||
session.close()
|
||||
}
|
||||
s.sessions = make(map[uint32]*PostgreSQLSession)
|
||||
s.sessionMux.Unlock()
|
||||
|
||||
s.wg.Wait()
|
||||
glog.Infof("PostgreSQL Server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptConnections handles incoming PostgreSQL connections
|
||||
func (s *PostgreSQLServer) acceptConnections() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
default:
|
||||
glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection limit
|
||||
s.sessionMux.RLock()
|
||||
sessionCount := len(s.sessions)
|
||||
s.sessionMux.RUnlock()
|
||||
|
||||
if sessionCount >= s.config.MaxConns {
|
||||
glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
|
||||
s.config.MaxConns, conn.RemoteAddr())
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection processes a single PostgreSQL connection
|
||||
func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
// Generate unique connection ID
|
||||
connID := s.generateConnectionID()
|
||||
secretKey := s.generateSecretKey()
|
||||
|
||||
// Create session
|
||||
session := &PostgreSQLSession{
|
||||
conn: conn,
|
||||
reader: bufio.NewReader(conn),
|
||||
writer: bufio.NewWriter(conn),
|
||||
authenticated: false,
|
||||
database: s.config.Database,
|
||||
parameters: make(map[string]string),
|
||||
preparedStmts: make(map[string]*PreparedStatement),
|
||||
portals: make(map[string]*Portal),
|
||||
transactionState: PG_TRANS_IDLE,
|
||||
processID: connID,
|
||||
secretKey: secretKey,
|
||||
created: time.Now(),
|
||||
lastActivity: time.Now(),
|
||||
}
|
||||
|
||||
// Register session
|
||||
s.sessionMux.Lock()
|
||||
s.sessions[connID] = session
|
||||
s.sessionMux.Unlock()
|
||||
|
||||
// Clean up on exit
|
||||
defer func() {
|
||||
s.sessionMux.Lock()
|
||||
delete(s.sessions, connID)
|
||||
s.sessionMux.Unlock()
|
||||
}()
|
||||
|
||||
glog.Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
|
||||
|
||||
// Handle startup
|
||||
err := s.handleStartup(session)
|
||||
if err != nil {
|
||||
glog.Errorf("Startup failed for connection %d: %v", connID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle messages
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set read timeout
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
err := s.handleMessage(session)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
|
||||
} else {
|
||||
glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
session.lastActivity = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// handleStartup processes the PostgreSQL startup sequence
|
||||
func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
|
||||
// Read startup message
|
||||
length := make([]byte, 4)
|
||||
_, err := io.ReadFull(session.reader, length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msgLength := binary.BigEndian.Uint32(length) - 4
|
||||
msg := make([]byte, msgLength)
|
||||
_, err = io.ReadFull(session.reader, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse startup message
|
||||
protocolVersion := binary.BigEndian.Uint32(msg[0:4])
|
||||
if protocolVersion != 196608 { // PostgreSQL protocol version 3.0
|
||||
return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
params := strings.Split(string(msg[4:]), "\x00")
|
||||
for i := 0; i < len(params)-1; i += 2 {
|
||||
if params[i] == "user" {
|
||||
session.username = params[i+1]
|
||||
} else if params[i] == "database" {
|
||||
session.database = params[i+1]
|
||||
}
|
||||
session.parameters[params[i]] = params[i+1]
|
||||
}
|
||||
|
||||
// Handle authentication
|
||||
err = s.handleAuthentication(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send parameter status messages
|
||||
err = s.sendParameterStatus(session, "server_version", "14.0 (SeaweedFS)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.sendParameterStatus(session, "server_encoding", "UTF8")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.sendParameterStatus(session, "client_encoding", "UTF8")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.sendParameterStatus(session, "integer_datetimes", "on")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send backend key data
|
||||
err = s.sendBackendKeyData(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send ready for query
|
||||
err = s.sendReadyForQuery(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.authenticated = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAuthentication processes authentication
|
||||
func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
|
||||
switch s.config.AuthMethod {
|
||||
case AuthTrust:
|
||||
return s.sendAuthenticationOk(session)
|
||||
case AuthPassword:
|
||||
return s.handlePasswordAuth(session)
|
||||
case AuthMD5:
|
||||
return s.handleMD5Auth(session)
|
||||
default:
|
||||
return fmt.Errorf("unsupported authentication method")
|
||||
}
|
||||
}
|
||||
|
||||
// sendAuthenticationOk sends authentication OK message
|
||||
func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
|
||||
msg := make([]byte, 8)
|
||||
msg[0] = PG_RESP_AUTH_OK
|
||||
binary.BigEndian.PutUint32(msg[1:5], 8)
|
||||
binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
|
||||
|
||||
_, err := session.writer.Write(msg)
|
||||
if err == nil {
|
||||
err = session.writer.Flush()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handlePasswordAuth handles clear password authentication
|
||||
func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
|
||||
// Send password request
|
||||
msg := make([]byte, 8)
|
||||
msg[0] = PG_RESP_AUTH_OK
|
||||
binary.BigEndian.PutUint32(msg[1:5], 8)
|
||||
binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
|
||||
|
||||
_, err := session.writer.Write(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.writer.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read password response
|
||||
msgType := make([]byte, 1)
|
||||
_, err = io.ReadFull(session.reader, msgType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType[0] != PG_MSG_PASSWORD {
|
||||
return fmt.Errorf("expected password message, got %c", msgType[0])
|
||||
}
|
||||
|
||||
length := make([]byte, 4)
|
||||
_, err = io.ReadFull(session.reader, length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msgLength := binary.BigEndian.Uint32(length) - 4
|
||||
password := make([]byte, msgLength)
|
||||
_, err = io.ReadFull(session.reader, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify password
|
||||
expectedPassword, exists := s.config.Users[session.username]
|
||||
if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
|
||||
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
||||
}
|
||||
|
||||
return s.sendAuthenticationOk(session)
|
||||
}
|
||||
|
||||
// handleMD5Auth handles MD5 password authentication
|
||||
func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
|
||||
// Generate salt
|
||||
salt := make([]byte, 4)
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send MD5 request
|
||||
msg := make([]byte, 12)
|
||||
msg[0] = PG_RESP_AUTH_OK
|
||||
binary.BigEndian.PutUint32(msg[1:5], 12)
|
||||
binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
|
||||
copy(msg[9:13], salt)
|
||||
|
||||
_, err = session.writer.Write(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.writer.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read password response
|
||||
msgType := make([]byte, 1)
|
||||
_, err = io.ReadFull(session.reader, msgType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType[0] != PG_MSG_PASSWORD {
|
||||
return fmt.Errorf("expected password message, got %c", msgType[0])
|
||||
}
|
||||
|
||||
length := make([]byte, 4)
|
||||
_, err = io.ReadFull(session.reader, length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msgLength := binary.BigEndian.Uint32(length) - 4
|
||||
response := make([]byte, msgLength)
|
||||
_, err = io.ReadFull(session.reader, response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify MD5 hash
|
||||
expectedPassword, exists := s.config.Users[session.username]
|
||||
if !exists {
|
||||
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
||||
}
|
||||
|
||||
// Calculate expected hash: md5(md5(password + username) + salt)
|
||||
inner := md5.Sum([]byte(expectedPassword + session.username))
|
||||
expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
|
||||
|
||||
if string(response[:len(response)-1]) != expected { // Remove null terminator
|
||||
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
||||
}
|
||||
|
||||
return s.sendAuthenticationOk(session)
|
||||
}
|
||||
|
||||
// generateConnectionID generates a unique connection ID
|
||||
func (s *PostgreSQLServer) generateConnectionID() uint32 {
|
||||
s.sessionMux.Lock()
|
||||
defer s.sessionMux.Unlock()
|
||||
id := s.nextConnID
|
||||
s.nextConnID++
|
||||
return id
|
||||
}
|
||||
|
||||
// generateSecretKey generates a secret key for the connection
|
||||
func (s *PostgreSQLServer) generateSecretKey() uint32 {
|
||||
key := make([]byte, 4)
|
||||
rand.Read(key)
|
||||
return binary.BigEndian.Uint32(key)
|
||||
}
|
||||
|
||||
// close marks the session as closed
|
||||
func (s *PostgreSQLSession) close() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupSessions periodically cleans up idle sessions
|
||||
func (s *PostgreSQLServer) cleanupSessions() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.cleanupIdleSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupIdleSessions removes sessions that have been idle too long
|
||||
func (s *PostgreSQLServer) cleanupIdleSessions() {
|
||||
now := time.Now()
|
||||
|
||||
s.sessionMux.Lock()
|
||||
defer s.sessionMux.Unlock()
|
||||
|
||||
for id, session := range s.sessions {
|
||||
if now.Sub(session.lastActivity) > s.config.IdleTimeout {
|
||||
glog.Infof("Closing idle PostgreSQL session %d", id)
|
||||
session.close()
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAddress returns the server address
|
||||
func (s *PostgreSQLServer) GetAddress() string {
|
||||
return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
|
||||
}
|
Reference in New Issue
Block a user