mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2025-11-24 08:46:54 +08:00
adding cors support (#6987)
* adding cors support * address some comments * optimize matchesWildcard * address comments * fix for tests * address comments * address comments * address comments * path building * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment Service-level responses need both Access-Control-Allow-Methods and Access-Control-Allow-Headers. After setting Access-Control-Allow-Origin and Access-Control-Expose-Headers, also set Access-Control-Allow-Methods: * and Access-Control-Allow-Headers: * so service endpoints satisfy CORS preflight requirements. * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_server.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify * add cors tests * fix tests * fix tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
649
weed/s3api/cors/cors.go
Normal file
649
weed/s3api/cors/cors.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
)
|
||||
|
||||
// S3 metadata file name constant to avoid typos and reduce duplication
|
||||
const S3MetadataFileName = ".s3metadata"
|
||||
|
||||
// CORSRule represents a single CORS rule
|
||||
type CORSRule struct {
|
||||
ID string `xml:"ID,omitempty" json:"ID,omitempty"`
|
||||
AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`
|
||||
AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`
|
||||
AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`
|
||||
ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`
|
||||
MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`
|
||||
}
|
||||
|
||||
// CORSConfiguration represents the CORS configuration for a bucket
|
||||
type CORSConfiguration struct {
|
||||
XMLName xml.Name `xml:"CORSConfiguration"`
|
||||
CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
|
||||
}
|
||||
|
||||
// CORSRequest represents a CORS request
|
||||
type CORSRequest struct {
|
||||
Origin string
|
||||
Method string
|
||||
RequestHeaders []string
|
||||
IsPreflightRequest bool
|
||||
AccessControlRequestMethod string
|
||||
AccessControlRequestHeaders []string
|
||||
}
|
||||
|
||||
// CORSResponse represents CORS response headers
|
||||
type CORSResponse struct {
|
||||
AllowOrigin string
|
||||
AllowMethods string
|
||||
AllowHeaders string
|
||||
ExposeHeaders string
|
||||
MaxAge string
|
||||
AllowCredentials bool
|
||||
}
|
||||
|
||||
// ValidateConfiguration validates a CORS configuration
|
||||
func ValidateConfiguration(config *CORSConfiguration) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("CORS configuration cannot be nil")
|
||||
}
|
||||
|
||||
if len(config.CORSRules) == 0 {
|
||||
return fmt.Errorf("CORS configuration must have at least one rule")
|
||||
}
|
||||
|
||||
if len(config.CORSRules) > 100 {
|
||||
return fmt.Errorf("CORS configuration cannot have more than 100 rules")
|
||||
}
|
||||
|
||||
for i, rule := range config.CORSRules {
|
||||
if err := validateRule(&rule); err != nil {
|
||||
return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRule validates a single CORS rule
|
||||
func validateRule(rule *CORSRule) error {
|
||||
if len(rule.AllowedMethods) == 0 {
|
||||
return fmt.Errorf("AllowedMethods cannot be empty")
|
||||
}
|
||||
|
||||
if len(rule.AllowedOrigins) == 0 {
|
||||
return fmt.Errorf("AllowedOrigins cannot be empty")
|
||||
}
|
||||
|
||||
// Validate allowed methods
|
||||
validMethods := map[string]bool{
|
||||
"GET": true,
|
||||
"PUT": true,
|
||||
"POST": true,
|
||||
"DELETE": true,
|
||||
"HEAD": true,
|
||||
}
|
||||
|
||||
for _, method := range rule.AllowedMethods {
|
||||
if !validMethods[method] {
|
||||
return fmt.Errorf("invalid HTTP method: %s", method)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate origins
|
||||
for _, origin := range rule.AllowedOrigins {
|
||||
if origin == "*" {
|
||||
continue
|
||||
}
|
||||
if err := validateOrigin(origin); err != nil {
|
||||
return fmt.Errorf("invalid origin %s: %v", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate MaxAgeSeconds
|
||||
if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
|
||||
return fmt.Errorf("MaxAgeSeconds cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOrigin validates an origin string
|
||||
func validateOrigin(origin string) error {
|
||||
if origin == "" {
|
||||
return fmt.Errorf("origin cannot be empty")
|
||||
}
|
||||
|
||||
// Special case: "*" is always valid
|
||||
if origin == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count wildcards
|
||||
wildcardCount := strings.Count(origin, "*")
|
||||
if wildcardCount > 1 {
|
||||
return fmt.Errorf("origin can contain at most one wildcard")
|
||||
}
|
||||
|
||||
// If there's a wildcard, it should be in a valid position
|
||||
if wildcardCount == 1 {
|
||||
// Must be in the format: http://*.example.com or https://*.example.com
|
||||
if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
|
||||
return fmt.Errorf("origin with wildcard must start with http:// or https://")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseRequest parses an HTTP request to extract CORS information
|
||||
func ParseRequest(r *http.Request) *CORSRequest {
|
||||
corsReq := &CORSRequest{
|
||||
Origin: r.Header.Get("Origin"),
|
||||
Method: r.Method,
|
||||
}
|
||||
|
||||
// Check if this is a preflight request
|
||||
if r.Method == "OPTIONS" {
|
||||
corsReq.IsPreflightRequest = true
|
||||
corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
|
||||
|
||||
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
|
||||
for i := range corsReq.AccessControlRequestHeaders {
|
||||
corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return corsReq
|
||||
}
|
||||
|
||||
// EvaluateRequest evaluates a CORS request against a CORS configuration
|
||||
func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
|
||||
if config == nil || corsReq == nil {
|
||||
return nil, fmt.Errorf("config and corsReq cannot be nil")
|
||||
}
|
||||
|
||||
if corsReq.Origin == "" {
|
||||
return nil, fmt.Errorf("origin header is required for CORS requests")
|
||||
}
|
||||
|
||||
// Find the first rule that matches the origin
|
||||
for _, rule := range config.CORSRules {
|
||||
if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
|
||||
// For preflight requests, we need more detailed validation
|
||||
if corsReq.IsPreflightRequest {
|
||||
return buildPreflightResponse(&rule, corsReq), nil
|
||||
} else {
|
||||
// For actual requests, check method
|
||||
if contains(rule.AllowedMethods, corsReq.Method) {
|
||||
return buildResponse(&rule, corsReq), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no matching CORS rule found")
|
||||
}
|
||||
|
||||
// matchesRule checks if a CORS request matches a CORS rule
|
||||
func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool {
|
||||
// Check origin - this is the primary matching criterion
|
||||
if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
|
||||
return false
|
||||
}
|
||||
|
||||
// For preflight requests, we need to validate both the requested method and headers
|
||||
if corsReq.IsPreflightRequest {
|
||||
// Check if the requested method is allowed
|
||||
if corsReq.AccessControlRequestMethod != "" {
|
||||
if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all requested headers are allowed
|
||||
if len(corsReq.AccessControlRequestHeaders) > 0 {
|
||||
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
|
||||
if !matchesHeader(rule.AllowedHeaders, requestedHeader) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// For non-preflight requests, check method matching
|
||||
method := corsReq.Method
|
||||
if !contains(rule.AllowedMethods, method) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchesOrigin checks if an origin matches any of the allowed origins
|
||||
func matchesOrigin(allowedOrigins []string, origin string) bool {
|
||||
for _, allowedOrigin := range allowedOrigins {
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
if allowedOrigin == origin {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check wildcard matching
|
||||
if strings.Contains(allowedOrigin, "*") {
|
||||
if matchesWildcard(allowedOrigin, origin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesWildcard checks if an origin matches a wildcard pattern
|
||||
// Uses string manipulation instead of regex for better performance
|
||||
func matchesWildcard(pattern, origin string) bool {
|
||||
// Handle simple cases first
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if pattern == origin {
|
||||
return true
|
||||
}
|
||||
|
||||
// For CORS, we typically only deal with * wildcards (not ? wildcards)
|
||||
// Use string manipulation for * wildcards only (more efficient than regex)
|
||||
|
||||
// Split pattern by wildcards
|
||||
parts := strings.Split(pattern, "*")
|
||||
if len(parts) == 1 {
|
||||
// No wildcards, exact match
|
||||
return pattern == origin
|
||||
}
|
||||
|
||||
// Check if string starts with first part
|
||||
if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if string ends with last part
|
||||
if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check middle parts
|
||||
searchStr := origin
|
||||
if len(parts[0]) > 0 {
|
||||
searchStr = searchStr[len(parts[0]):]
|
||||
}
|
||||
if len(parts[len(parts)-1]) > 0 {
|
||||
searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])]
|
||||
}
|
||||
|
||||
for i := 1; i < len(parts)-1; i++ {
|
||||
if len(parts[i]) > 0 {
|
||||
index := strings.Index(searchStr, parts[i])
|
||||
if index == -1 {
|
||||
return false
|
||||
}
|
||||
searchStr = searchStr[index+len(parts[i]):]
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchesHeader checks if a header matches allowed headers
|
||||
func matchesHeader(allowedHeaders []string, header string) bool {
|
||||
if len(allowedHeaders) == 0 {
|
||||
return true // No restrictions
|
||||
}
|
||||
|
||||
for _, allowedHeader := range allowedHeaders {
|
||||
if allowedHeader == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.EqualFold(allowedHeader, header) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check wildcard matching for headers
|
||||
if strings.Contains(allowedHeader, "*") {
|
||||
if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// buildPreflightResponse builds a CORS response for preflight requests
|
||||
// This function allows partial matches - origin can match while methods/headers may not
|
||||
func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
|
||||
response := &CORSResponse{
|
||||
AllowOrigin: corsReq.Origin,
|
||||
}
|
||||
|
||||
// Check if the requested method is allowed
|
||||
methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
|
||||
|
||||
// Check requested headers
|
||||
var allowedRequestHeaders []string
|
||||
allHeadersAllowed := true
|
||||
|
||||
if len(corsReq.AccessControlRequestHeaders) > 0 {
|
||||
// Check if wildcard is allowed
|
||||
hasWildcard := false
|
||||
for _, header := range rule.AllowedHeaders {
|
||||
if header == "*" {
|
||||
hasWildcard = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWildcard {
|
||||
// All requested headers are allowed with wildcard
|
||||
allowedRequestHeaders = corsReq.AccessControlRequestHeaders
|
||||
} else {
|
||||
// Check each requested header individually
|
||||
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
|
||||
if matchesHeader(rule.AllowedHeaders, requestedHeader) {
|
||||
allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
|
||||
} else {
|
||||
allHeadersAllowed = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only set method and header info if both method and ALL headers are allowed
|
||||
if methodAllowed && allHeadersAllowed {
|
||||
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
|
||||
|
||||
if len(allowedRequestHeaders) > 0 {
|
||||
response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
|
||||
}
|
||||
|
||||
// Set exposed headers
|
||||
if len(rule.ExposeHeaders) > 0 {
|
||||
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
|
||||
}
|
||||
|
||||
// Set max age
|
||||
if rule.MaxAgeSeconds != nil {
|
||||
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// buildResponse builds a CORS response from a matching rule
|
||||
func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
|
||||
response := &CORSResponse{
|
||||
AllowOrigin: corsReq.Origin,
|
||||
}
|
||||
|
||||
// Set allowed methods - for preflight requests, return all allowed methods
|
||||
if corsReq.IsPreflightRequest {
|
||||
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
|
||||
} else {
|
||||
// For non-preflight requests, return all allowed methods
|
||||
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
|
||||
}
|
||||
|
||||
// Set allowed headers
|
||||
if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 {
|
||||
// For preflight requests, check if wildcard is allowed
|
||||
hasWildcard := false
|
||||
for _, header := range rule.AllowedHeaders {
|
||||
if header == "*" {
|
||||
hasWildcard = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 {
|
||||
// Return the specific headers that were requested when wildcard is allowed
|
||||
response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
|
||||
} else if len(corsReq.AccessControlRequestHeaders) > 0 {
|
||||
// For non-wildcard cases, return the requested headers (preserving case)
|
||||
// since we already validated they are allowed in matchesRule
|
||||
response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
|
||||
} else {
|
||||
// Fallback to configured headers if no specific headers were requested
|
||||
response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
|
||||
}
|
||||
} else if len(rule.AllowedHeaders) > 0 {
|
||||
// For non-preflight requests, return the allowed headers from the rule
|
||||
response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
|
||||
}
|
||||
|
||||
// Set exposed headers
|
||||
if len(rule.ExposeHeaders) > 0 {
|
||||
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
|
||||
}
|
||||
|
||||
// Set max age
|
||||
if rule.MaxAgeSeconds != nil {
|
||||
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// contains checks if a slice contains a string
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ApplyHeaders applies CORS headers to an HTTP response
|
||||
func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
|
||||
if corsResp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if corsResp.AllowOrigin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
|
||||
}
|
||||
|
||||
if corsResp.AllowMethods != "" {
|
||||
w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
|
||||
}
|
||||
|
||||
if corsResp.AllowHeaders != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
|
||||
}
|
||||
|
||||
if corsResp.ExposeHeaders != "" {
|
||||
w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
|
||||
}
|
||||
|
||||
if corsResp.MaxAge != "" {
|
||||
w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
|
||||
}
|
||||
|
||||
if corsResp.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
// FilerClient interface for dependency injection
|
||||
type FilerClient interface {
|
||||
WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error
|
||||
}
|
||||
|
||||
// EntryGetter interface for getting filer entries
|
||||
type EntryGetter interface {
|
||||
GetEntry(directory, name string) (*filer_pb.Entry, error)
|
||||
}
|
||||
|
||||
// Storage provides CORS configuration storage operations
|
||||
type Storage struct {
|
||||
filerClient FilerClient
|
||||
entryGetter EntryGetter
|
||||
bucketsPath string
|
||||
}
|
||||
|
||||
// NewStorage creates a new CORS storage instance
|
||||
func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage {
|
||||
return &Storage{
|
||||
filerClient: filerClient,
|
||||
entryGetter: entryGetter,
|
||||
bucketsPath: bucketsPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Store stores CORS configuration in the filer
|
||||
func (s *Storage) Store(bucket string, config *CORSConfiguration) error {
|
||||
// Store in bucket metadata
|
||||
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
|
||||
|
||||
// Get existing metadata
|
||||
existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
|
||||
var metadata map[string]interface{}
|
||||
|
||||
if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 {
|
||||
if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil {
|
||||
glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err)
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
} else {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metadata["cors"] = config
|
||||
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal bucket metadata: %v", err)
|
||||
}
|
||||
|
||||
// Store metadata
|
||||
return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
|
||||
request := &filer_pb.CreateEntryRequest{
|
||||
Directory: s.bucketsPath + "/" + bucket,
|
||||
Entry: &filer_pb.Entry{
|
||||
Name: S3MetadataFileName,
|
||||
IsDirectory: false,
|
||||
Attributes: &filer_pb.FuseAttributes{
|
||||
Crtime: time.Now().Unix(),
|
||||
Mtime: time.Now().Unix(),
|
||||
FileMode: 0644,
|
||||
},
|
||||
Content: metadataBytes,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := client.CreateEntry(context.Background(), request)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Load loads CORS configuration from the filer
|
||||
func (s *Storage) Load(bucket string) (*CORSConfiguration, error) {
|
||||
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
|
||||
|
||||
entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
|
||||
if err != nil || entry == nil {
|
||||
return nil, fmt.Errorf("no CORS configuration found")
|
||||
}
|
||||
|
||||
if len(entry.Content) == 0 {
|
||||
return nil, fmt.Errorf("no CORS configuration found")
|
||||
}
|
||||
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal(entry.Content, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
|
||||
corsData, exists := metadata["cors"]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no CORS configuration found")
|
||||
}
|
||||
|
||||
// Convert back to CORSConfiguration
|
||||
corsBytes, err := json.Marshal(corsData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal CORS data: %v", err)
|
||||
}
|
||||
|
||||
var config CORSConfiguration
|
||||
if err := json.Unmarshal(corsBytes, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// Delete deletes CORS configuration from the filer
|
||||
func (s *Storage) Delete(bucket string) error {
|
||||
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
|
||||
|
||||
entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
|
||||
if err != nil || entry == nil {
|
||||
return nil // Already deleted or doesn't exist
|
||||
}
|
||||
|
||||
var metadata map[string]interface{}
|
||||
if len(entry.Content) > 0 {
|
||||
if err := json.Unmarshal(entry.Content, &metadata); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
return nil // No metadata to delete
|
||||
}
|
||||
|
||||
// Remove CORS configuration
|
||||
delete(metadata, "cors")
|
||||
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata: %v", err)
|
||||
}
|
||||
|
||||
// Update metadata
|
||||
return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
|
||||
request := &filer_pb.CreateEntryRequest{
|
||||
Directory: s.bucketsPath + "/" + bucket,
|
||||
Entry: &filer_pb.Entry{
|
||||
Name: S3MetadataFileName,
|
||||
IsDirectory: false,
|
||||
Attributes: &filer_pb.FuseAttributes{
|
||||
Crtime: time.Now().Unix(),
|
||||
Mtime: time.Now().Unix(),
|
||||
FileMode: 0644,
|
||||
},
|
||||
Content: metadataBytes,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := client.CreateEntry(context.Background(), request)
|
||||
return err
|
||||
})
|
||||
}
|
||||
526
weed/s3api/cors/cors_test.go
Normal file
526
weed/s3api/cors/cors_test.go
Normal file
@@ -0,0 +1,526 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateConfiguration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *CORSConfiguration
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty rules",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid single rule",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedOrigins: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too many rules",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: make([]CORSRule, 101),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"INVALID"},
|
||||
AllowedOrigins: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty origins",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"GET"},
|
||||
AllowedOrigins: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid origin with multiple wildcards",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"GET"},
|
||||
AllowedOrigins: []string{"http://*.*.example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative MaxAgeSeconds",
|
||||
config: &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"GET"},
|
||||
AllowedOrigins: []string{"*"},
|
||||
MaxAgeSeconds: intPtr(-1),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateConfiguration(tt.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty origin",
|
||||
origin: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid origin",
|
||||
origin: "http://example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard origin",
|
||||
origin: "*",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid wildcard origin",
|
||||
origin: "http://*.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "https wildcard origin",
|
||||
origin: "https://*.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid wildcard origin",
|
||||
origin: "*.example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple wildcards",
|
||||
origin: "http://*.*.example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateOrigin(tt.origin)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
want *CORSRequest
|
||||
}{
|
||||
{
|
||||
name: "simple GET request",
|
||||
req: &http.Request{
|
||||
Method: "GET",
|
||||
Header: http.Header{
|
||||
"Origin": []string{"http://example.com"},
|
||||
},
|
||||
},
|
||||
want: &CORSRequest{
|
||||
Origin: "http://example.com",
|
||||
Method: "GET",
|
||||
IsPreflightRequest: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OPTIONS preflight request",
|
||||
req: &http.Request{
|
||||
Method: "OPTIONS",
|
||||
Header: http.Header{
|
||||
"Origin": []string{"http://example.com"},
|
||||
"Access-Control-Request-Method": []string{"PUT"},
|
||||
"Access-Control-Request-Headers": []string{"Content-Type, Authorization"},
|
||||
},
|
||||
},
|
||||
want: &CORSRequest{
|
||||
Origin: "http://example.com",
|
||||
Method: "OPTIONS",
|
||||
IsPreflightRequest: true,
|
||||
AccessControlRequestMethod: "PUT",
|
||||
AccessControlRequestHeaders: []string{"Content-Type", "Authorization"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request without origin",
|
||||
req: &http.Request{
|
||||
Method: "GET",
|
||||
Header: http.Header{},
|
||||
},
|
||||
want: &CORSRequest{
|
||||
Origin: "",
|
||||
Method: "GET",
|
||||
IsPreflightRequest: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ParseRequest(tt.req)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedOrigins []string
|
||||
origin string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard match",
|
||||
allowedOrigins: []string{"*"},
|
||||
origin: "http://example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact match",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://other.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard subdomain match",
|
||||
allowedOrigins: []string{"http://*.example.com"},
|
||||
origin: "http://api.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard subdomain no match",
|
||||
allowedOrigins: []string{"http://*.example.com"},
|
||||
origin: "http://example.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multiple origins with match",
|
||||
allowedOrigins: []string{"http://example.com", "http://other.com"},
|
||||
origin: "http://other.com",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := matchesOrigin(tt.allowedOrigins, tt.origin)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchesOrigin() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedHeaders []string
|
||||
header string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty allowed headers",
|
||||
allowedHeaders: []string{},
|
||||
header: "Content-Type",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
allowedHeaders: []string{"*"},
|
||||
header: "Content-Type",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact match",
|
||||
allowedHeaders: []string{"Content-Type"},
|
||||
header: "Content-Type",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
allowedHeaders: []string{"content-type"},
|
||||
header: "Content-Type",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
allowedHeaders: []string{"Authorization"},
|
||||
header: "Content-Type",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard prefix match",
|
||||
allowedHeaders: []string{"x-amz-*"},
|
||||
header: "x-amz-date",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := matchesHeader(tt.allowedHeaders, tt.header)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchesHeader() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluateRequest(t *testing.T) {
|
||||
config := &CORSConfiguration{
|
||||
CORSRules: []CORSRule{
|
||||
{
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
ExposeHeaders: []string{"ETag"},
|
||||
MaxAgeSeconds: intPtr(3600),
|
||||
},
|
||||
{
|
||||
AllowedMethods: []string{"PUT"},
|
||||
AllowedOrigins: []string{"*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *CORSConfiguration
|
||||
corsReq *CORSRequest
|
||||
want *CORSResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "matching first rule",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "http://example.com",
|
||||
Method: "GET",
|
||||
},
|
||||
want: &CORSResponse{
|
||||
AllowOrigin: "http://example.com",
|
||||
AllowMethods: "GET, POST",
|
||||
AllowHeaders: "Content-Type",
|
||||
ExposeHeaders: "ETag",
|
||||
MaxAge: "3600",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "matching second rule",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "http://other.com",
|
||||
Method: "PUT",
|
||||
},
|
||||
want: &CORSResponse{
|
||||
AllowOrigin: "http://other.com",
|
||||
AllowMethods: "PUT",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no matching rule",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "http://forbidden.com",
|
||||
Method: "GET",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "preflight request",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "http://example.com",
|
||||
Method: "OPTIONS",
|
||||
IsPreflightRequest: true,
|
||||
AccessControlRequestMethod: "POST",
|
||||
AccessControlRequestHeaders: []string{"Content-Type"},
|
||||
},
|
||||
want: &CORSResponse{
|
||||
AllowOrigin: "http://example.com",
|
||||
AllowMethods: "GET, POST",
|
||||
AllowHeaders: "Content-Type",
|
||||
ExposeHeaders: "ETag",
|
||||
MaxAge: "3600",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "preflight request with forbidden header",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "http://example.com",
|
||||
Method: "OPTIONS",
|
||||
IsPreflightRequest: true,
|
||||
AccessControlRequestMethod: "POST",
|
||||
AccessControlRequestHeaders: []string{"Authorization"},
|
||||
},
|
||||
want: &CORSResponse{
|
||||
AllowOrigin: "http://example.com",
|
||||
// No AllowMethods or AllowHeaders because the requested header is forbidden
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "request without origin",
|
||||
config: config,
|
||||
corsReq: &CORSRequest{
|
||||
Origin: "",
|
||||
Method: "GET",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := EvaluateRequest(tt.config, tt.corsReq)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
corsResp *CORSResponse
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "nil response",
|
||||
corsResp: nil,
|
||||
want: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "complete response",
|
||||
corsResp: &CORSResponse{
|
||||
AllowOrigin: "http://example.com",
|
||||
AllowMethods: "GET, POST",
|
||||
AllowHeaders: "Content-Type",
|
||||
ExposeHeaders: "ETag",
|
||||
MaxAge: "3600",
|
||||
},
|
||||
want: map[string]string{
|
||||
"Access-Control-Allow-Origin": "http://example.com",
|
||||
"Access-Control-Allow-Methods": "GET, POST",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Expose-Headers": "ETag",
|
||||
"Access-Control-Max-Age": "3600",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with credentials",
|
||||
corsResp: &CORSResponse{
|
||||
AllowOrigin: "http://example.com",
|
||||
AllowMethods: "GET",
|
||||
AllowCredentials: true,
|
||||
},
|
||||
want: map[string]string{
|
||||
"Access-Control-Allow-Origin": "http://example.com",
|
||||
"Access-Control-Allow-Methods": "GET",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a proper response writer using httptest
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ApplyHeaders(w, tt.corsResp)
|
||||
|
||||
// Extract headers from the response
|
||||
headers := make(map[string]string)
|
||||
for key, values := range w.Header() {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(headers, tt.want) {
|
||||
t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions and types for testing
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
143
weed/s3api/cors/middleware.go
Normal file
143
weed/s3api/cors/middleware.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
|
||||
)
|
||||
|
||||
// BucketChecker interface for checking bucket existence
|
||||
type BucketChecker interface {
|
||||
CheckBucket(r *http.Request, bucket string) s3err.ErrorCode
|
||||
}
|
||||
|
||||
// CORSConfigGetter interface for getting CORS configuration
|
||||
type CORSConfigGetter interface {
|
||||
GetCORSConfiguration(bucket string) (*CORSConfiguration, s3err.ErrorCode)
|
||||
}
|
||||
|
||||
// Middleware handles CORS evaluation for all S3 API requests
|
||||
type Middleware struct {
|
||||
storage *Storage
|
||||
bucketChecker BucketChecker
|
||||
corsConfigGetter CORSConfigGetter
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new CORS middleware instance
|
||||
func NewMiddleware(storage *Storage, bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter) *Middleware {
|
||||
return &Middleware{
|
||||
storage: storage,
|
||||
bucketChecker: bucketChecker,
|
||||
corsConfigGetter: corsConfigGetter,
|
||||
}
|
||||
}
|
||||
|
||||
// evaluateCORSRequest performs the common CORS request evaluation logic
|
||||
// Returns: (corsResponse, responseWritten, shouldContinue)
|
||||
// - corsResponse: the CORS response if evaluation succeeded
|
||||
// - responseWritten: true if an error response was already written
|
||||
// - shouldContinue: true if the request should continue to the next handler
|
||||
func (m *Middleware) evaluateCORSRequest(w http.ResponseWriter, r *http.Request) (*CORSResponse, bool, bool) {
|
||||
// Parse CORS request
|
||||
corsReq := ParseRequest(r)
|
||||
if corsReq.Origin == "" {
|
||||
// Not a CORS request
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
// Extract bucket from request
|
||||
bucket, _ := s3_constants.GetBucketAndObject(r)
|
||||
if bucket == "" {
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
// Check if bucket exists
|
||||
if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone {
|
||||
// For non-existent buckets, let the normal handler deal with it
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
// Load CORS configuration from cache
|
||||
config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket)
|
||||
if errCode != s3err.ErrNone || config == nil {
|
||||
// No CORS configuration, handle based on request type
|
||||
if corsReq.IsPreflightRequest {
|
||||
// Preflight request without CORS config should fail
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
|
||||
return nil, true, false // Response written, don't continue
|
||||
}
|
||||
// Non-preflight request, continue normally
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
// Evaluate CORS request
|
||||
corsResp, err := EvaluateRequest(config, corsReq)
|
||||
if err != nil {
|
||||
glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err)
|
||||
if corsReq.IsPreflightRequest {
|
||||
// Preflight request that doesn't match CORS rules should fail
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
|
||||
return nil, true, false // Response written, don't continue
|
||||
}
|
||||
// Non-preflight request, continue normally but without CORS headers
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
return corsResp, false, false
|
||||
}
|
||||
|
||||
// Handler returns the CORS middleware handler
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use the common evaluation logic
|
||||
corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
|
||||
if responseWritten {
|
||||
// Response was already written (error case)
|
||||
return
|
||||
}
|
||||
|
||||
if shouldContinue {
|
||||
// Continue with normal request processing
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request to check if it's a preflight request
|
||||
corsReq := ParseRequest(r)
|
||||
|
||||
// Apply CORS headers to response
|
||||
ApplyHeaders(w, corsResp)
|
||||
|
||||
// Handle preflight requests
|
||||
if corsReq.IsPreflightRequest {
|
||||
// Preflight request should return 200 OK with just CORS headers
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Continue with normal request processing
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// HandleOptionsRequest handles OPTIONS requests for CORS preflight
|
||||
func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Use the common evaluation logic
|
||||
corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
|
||||
if responseWritten {
|
||||
// Response was already written (error case)
|
||||
return
|
||||
}
|
||||
|
||||
if shouldContinue || corsResp == nil {
|
||||
// Not a CORS request or should continue normally
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply CORS headers and return success
|
||||
ApplyHeaders(w, corsResp)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
package s3api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/cors"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
|
||||
)
|
||||
@@ -18,6 +22,7 @@ type BucketConfig struct {
|
||||
Ownership string
|
||||
ACL []byte
|
||||
Owner string
|
||||
CORS *cors.CORSConfiguration
|
||||
LastModified time.Time
|
||||
Entry *filer_pb.Entry
|
||||
}
|
||||
@@ -118,6 +123,19 @@ func (s3a *S3ApiServer) getBucketConfig(bucket string) (*BucketConfig, s3err.Err
|
||||
}
|
||||
}
|
||||
|
||||
// Load CORS configuration from .s3metadata
|
||||
if corsConfig, err := s3a.loadCORSFromMetadata(bucket); err != nil {
|
||||
if err == filer_pb.ErrNotFound {
|
||||
// Missing metadata is not an error; fall back cleanly
|
||||
glog.V(2).Infof("CORS metadata not found for bucket %s, falling back to default behavior", bucket)
|
||||
} else {
|
||||
// Log parsing or validation errors
|
||||
glog.Errorf("Failed to load CORS configuration for bucket %s: %v", bucket, err)
|
||||
}
|
||||
} else {
|
||||
config.CORS = corsConfig
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
s3a.bucketConfigCache.Set(bucket, config)
|
||||
|
||||
@@ -244,3 +262,114 @@ func (s3a *S3ApiServer) removeBucketConfigKey(bucket, key string) s3err.ErrorCod
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// loadCORSFromMetadata loads CORS configuration from bucket metadata
|
||||
func (s3a *S3ApiServer) loadCORSFromMetadata(bucket string) (*cors.CORSConfiguration, error) {
|
||||
// Validate bucket name to prevent path traversal attacks
|
||||
if bucket == "" || strings.Contains(bucket, "/") || strings.Contains(bucket, "\\") ||
|
||||
strings.Contains(bucket, "..") || strings.Contains(bucket, "~") {
|
||||
return nil, fmt.Errorf("invalid bucket name: %s", bucket)
|
||||
}
|
||||
|
||||
// Clean the bucket name further to prevent any potential path traversal
|
||||
bucket = filepath.Clean(bucket)
|
||||
if bucket == "." || bucket == ".." {
|
||||
return nil, fmt.Errorf("invalid bucket name: %s", bucket)
|
||||
}
|
||||
|
||||
bucketMetadataPath := filepath.Join(s3a.option.BucketsPath, bucket, cors.S3MetadataFileName)
|
||||
|
||||
entry, err := s3a.getEntry("", bucketMetadataPath)
|
||||
if err != nil {
|
||||
glog.V(3).Infof("loadCORSFromMetadata: error retrieving metadata for bucket %s: %v", bucket, err)
|
||||
return nil, fmt.Errorf("error retrieving metadata for bucket %s: %v", bucket, err)
|
||||
}
|
||||
if entry == nil {
|
||||
glog.V(3).Infof("loadCORSFromMetadata: no metadata entry found for bucket %s", bucket)
|
||||
return nil, fmt.Errorf("no metadata entry found for bucket %s", bucket)
|
||||
}
|
||||
|
||||
if len(entry.Content) == 0 {
|
||||
glog.V(3).Infof("loadCORSFromMetadata: empty metadata content for bucket %s", bucket)
|
||||
return nil, fmt.Errorf("no metadata content for bucket %s", bucket)
|
||||
}
|
||||
|
||||
var metadata map[string]json.RawMessage
|
||||
if err := json.Unmarshal(entry.Content, &metadata); err != nil {
|
||||
glog.Errorf("loadCORSFromMetadata: failed to unmarshal metadata for bucket %s: %v", bucket, err)
|
||||
return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
|
||||
corsData, exists := metadata["cors"]
|
||||
if !exists {
|
||||
glog.V(3).Infof("loadCORSFromMetadata: no CORS configuration found for bucket %s", bucket)
|
||||
return nil, fmt.Errorf("no CORS configuration found")
|
||||
}
|
||||
|
||||
// Directly unmarshal the raw JSON to CORSConfiguration to avoid round-trip allocations
|
||||
var config cors.CORSConfiguration
|
||||
if err := json.Unmarshal(corsData, &config); err != nil {
|
||||
glog.Errorf("loadCORSFromMetadata: failed to unmarshal CORS configuration for bucket %s: %v", bucket, err)
|
||||
return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// getCORSConfiguration retrieves CORS configuration with caching
|
||||
func (s3a *S3ApiServer) getCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) {
|
||||
config, errCode := s3a.getBucketConfig(bucket)
|
||||
if errCode != s3err.ErrNone {
|
||||
return nil, errCode
|
||||
}
|
||||
|
||||
return config.CORS, s3err.ErrNone
|
||||
}
|
||||
|
||||
// getCORSStorage returns a CORS storage instance for persistent operations
|
||||
func (s3a *S3ApiServer) getCORSStorage() *cors.Storage {
|
||||
entryGetter := &S3EntryGetter{server: s3a}
|
||||
return cors.NewStorage(s3a, entryGetter, s3a.option.BucketsPath)
|
||||
}
|
||||
|
||||
// updateCORSConfiguration updates CORS configuration and invalidates cache
|
||||
func (s3a *S3ApiServer) updateCORSConfiguration(bucket string, corsConfig *cors.CORSConfiguration) s3err.ErrorCode {
|
||||
// Update in-memory cache
|
||||
errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error {
|
||||
config.CORS = corsConfig
|
||||
return nil
|
||||
})
|
||||
if errCode != s3err.ErrNone {
|
||||
return errCode
|
||||
}
|
||||
|
||||
// Persist to .s3metadata file
|
||||
storage := s3a.getCORSStorage()
|
||||
if err := storage.Store(bucket, corsConfig); err != nil {
|
||||
glog.Errorf("updateCORSConfiguration: failed to persist CORS config to metadata for bucket %s: %v", bucket, err)
|
||||
return s3err.ErrInternalError
|
||||
}
|
||||
|
||||
return s3err.ErrNone
|
||||
}
|
||||
|
||||
// removeCORSConfiguration removes CORS configuration and invalidates cache
|
||||
func (s3a *S3ApiServer) removeCORSConfiguration(bucket string) s3err.ErrorCode {
|
||||
// Remove from in-memory cache
|
||||
errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error {
|
||||
config.CORS = nil
|
||||
return nil
|
||||
})
|
||||
if errCode != s3err.ErrNone {
|
||||
return errCode
|
||||
}
|
||||
|
||||
// Remove from .s3metadata file
|
||||
storage := s3a.getCORSStorage()
|
||||
if err := storage.Delete(bucket); err != nil {
|
||||
glog.Errorf("removeCORSConfiguration: failed to remove CORS config from metadata for bucket %s: %v", bucket, err)
|
||||
return s3err.ErrInternalError
|
||||
}
|
||||
|
||||
return s3err.ErrNone
|
||||
}
|
||||
|
||||
140
weed/s3api/s3api_bucket_cors_handlers.go
Normal file
140
weed/s3api/s3api_bucket_cors_handlers.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package s3api
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net/http"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/cors"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
|
||||
)
|
||||
|
||||
// S3EntryGetter implements cors.EntryGetter interface
|
||||
type S3EntryGetter struct {
|
||||
server *S3ApiServer
|
||||
}
|
||||
|
||||
func (g *S3EntryGetter) GetEntry(directory, name string) (*filer_pb.Entry, error) {
|
||||
return g.server.getEntry(directory, name)
|
||||
}
|
||||
|
||||
// S3BucketChecker implements cors.BucketChecker interface
|
||||
type S3BucketChecker struct {
|
||||
server *S3ApiServer
|
||||
}
|
||||
|
||||
func (c *S3BucketChecker) CheckBucket(r *http.Request, bucket string) s3err.ErrorCode {
|
||||
return c.server.checkBucket(r, bucket)
|
||||
}
|
||||
|
||||
// S3CORSConfigGetter implements cors.CORSConfigGetter interface
|
||||
type S3CORSConfigGetter struct {
|
||||
server *S3ApiServer
|
||||
}
|
||||
|
||||
func (g *S3CORSConfigGetter) GetCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) {
|
||||
return g.server.getCORSConfiguration(bucket)
|
||||
}
|
||||
|
||||
// getCORSMiddleware returns a CORS middleware instance with caching
|
||||
func (s3a *S3ApiServer) getCORSMiddleware() *cors.Middleware {
|
||||
storage := s3a.getCORSStorage()
|
||||
bucketChecker := &S3BucketChecker{server: s3a}
|
||||
corsConfigGetter := &S3CORSConfigGetter{server: s3a}
|
||||
|
||||
return cors.NewMiddleware(storage, bucketChecker, corsConfigGetter)
|
||||
}
|
||||
|
||||
// GetBucketCorsHandler handles Get bucket CORS configuration
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html
|
||||
func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
bucket, _ := s3_constants.GetBucketAndObject(r)
|
||||
glog.V(3).Infof("GetBucketCorsHandler %s", bucket)
|
||||
|
||||
if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
|
||||
s3err.WriteErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Load CORS configuration from cache
|
||||
config, errCode := s3a.getCORSConfiguration(bucket)
|
||||
if errCode != s3err.ErrNone {
|
||||
if errCode == s3err.ErrNoSuchBucket {
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchBucket)
|
||||
} else {
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration)
|
||||
return
|
||||
}
|
||||
|
||||
// Return CORS configuration as XML
|
||||
writeSuccessResponseXML(w, r, config)
|
||||
}
|
||||
|
||||
// PutBucketCorsHandler handles Put bucket CORS configuration
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html
|
||||
func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
bucket, _ := s3_constants.GetBucketAndObject(r)
|
||||
glog.V(3).Infof("PutBucketCorsHandler %s", bucket)
|
||||
|
||||
if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
|
||||
s3err.WriteErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse CORS configuration from request body
|
||||
var config cors.CORSConfiguration
|
||||
if err := xml.NewDecoder(r.Body).Decode(&config); err != nil {
|
||||
glog.V(1).Infof("Failed to parse CORS configuration: %v", err)
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrMalformedXML)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CORS configuration
|
||||
if err := cors.ValidateConfiguration(&config); err != nil {
|
||||
glog.V(1).Infof("Invalid CORS configuration: %v", err)
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store CORS configuration and update cache
|
||||
// This handles both cache update and persistent storage through the unified bucket config system
|
||||
if err := s3a.updateCORSConfiguration(bucket, &config); err != s3err.ErrNone {
|
||||
glog.Errorf("Failed to update CORS configuration: %v", err)
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success
|
||||
writeSuccessResponseEmpty(w, r)
|
||||
}
|
||||
|
||||
// DeleteBucketCorsHandler handles Delete bucket CORS configuration
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html
|
||||
func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
bucket, _ := s3_constants.GetBucketAndObject(r)
|
||||
glog.V(3).Infof("DeleteBucketCorsHandler %s", bucket)
|
||||
|
||||
if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
|
||||
s3err.WriteErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove CORS configuration from cache and persistent storage
|
||||
// This handles both cache invalidation and persistent storage cleanup through the unified bucket config system
|
||||
if err := s3a.removeCORSConfiguration(bucket); err != s3err.ErrNone {
|
||||
glog.Errorf("Failed to remove CORS configuration: %v", err)
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success (204 No Content)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -8,24 +8,6 @@ import (
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
|
||||
)
|
||||
|
||||
// GetBucketCorsHandler Get bucket CORS
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html
|
||||
func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration)
|
||||
}
|
||||
|
||||
// PutBucketCorsHandler Put bucket CORS
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html
|
||||
func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s3err.WriteErrorResponse(w, r, s3err.ErrNotImplemented)
|
||||
}
|
||||
|
||||
// DeleteBucketCorsHandler Delete bucket CORS
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html
|
||||
func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s3err.WriteErrorResponse(w, r, http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetBucketPolicyHandler Get bucket Policy
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketPolicy.html
|
||||
func (s3a *S3ApiServer) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -20,6 +20,17 @@ import (
|
||||
util_http "github.com/seaweedfs/seaweedfs/weed/util/http"
|
||||
)
|
||||
|
||||
// corsHeaders defines the CORS headers that need to be preserved
|
||||
// Package-level constant to avoid repeated allocations
|
||||
var corsHeaders = []string{
|
||||
"Access-Control-Allow-Origin",
|
||||
"Access-Control-Allow-Methods",
|
||||
"Access-Control-Allow-Headers",
|
||||
"Access-Control-Expose-Headers",
|
||||
"Access-Control-Max-Age",
|
||||
"Access-Control-Allow-Credentials",
|
||||
}
|
||||
|
||||
func mimeDetect(r *http.Request, dataReader io.Reader) io.ReadCloser {
|
||||
mimeBuffer := make([]byte, 512)
|
||||
size, _ := dataReader.Read(mimeBuffer)
|
||||
@@ -381,10 +392,34 @@ func setUserMetadataKeyToLowercase(resp *http.Response) {
|
||||
}
|
||||
}
|
||||
|
||||
func captureCORSHeaders(w http.ResponseWriter, headersToCapture []string) map[string]string {
|
||||
captured := make(map[string]string)
|
||||
for _, corsHeader := range headersToCapture {
|
||||
if value := w.Header().Get(corsHeader); value != "" {
|
||||
captured[corsHeader] = value
|
||||
}
|
||||
}
|
||||
return captured
|
||||
}
|
||||
|
||||
func restoreCORSHeaders(w http.ResponseWriter, capturedCORSHeaders map[string]string) {
|
||||
for corsHeader, value := range capturedCORSHeaders {
|
||||
w.Header().Set(corsHeader, value)
|
||||
}
|
||||
}
|
||||
|
||||
func passThroughResponse(proxyResponse *http.Response, w http.ResponseWriter) (statusCode int, bytesTransferred int64) {
|
||||
// Capture existing CORS headers that may have been set by middleware
|
||||
capturedCORSHeaders := captureCORSHeaders(w, corsHeaders)
|
||||
|
||||
// Copy headers from proxy response
|
||||
for k, v := range proxyResponse.Header {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
|
||||
// Restore CORS headers that were set by middleware
|
||||
restoreCORSHeaders(w, capturedCORSHeaders)
|
||||
|
||||
if proxyResponse.Header.Get("Content-Range") != "" && proxyResponse.StatusCode == 200 {
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
statusCode = http.StatusPartialContent
|
||||
|
||||
@@ -121,6 +121,35 @@ func NewS3ApiServerWithStore(router *mux.Router, option *S3ApiServerOption, expl
|
||||
return s3ApiServer, nil
|
||||
}
|
||||
|
||||
// handleCORSOriginValidation handles the common CORS origin validation logic
|
||||
func (s3a *S3ApiServer) handleCORSOriginValidation(w http.ResponseWriter, r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" {
|
||||
origin = "*"
|
||||
} else {
|
||||
originFound := false
|
||||
for _, allowedOrigin := range s3a.option.AllowedOrigins {
|
||||
if origin == allowedOrigin {
|
||||
originFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !originFound {
|
||||
writeFailureResponse(w, r, http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
return true
|
||||
}
|
||||
|
||||
func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
|
||||
// API Router
|
||||
apiRouter := router.PathPrefix("/").Subrouter()
|
||||
@@ -129,33 +158,6 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
|
||||
apiRouter.Methods(http.MethodGet).Path("/status").HandlerFunc(s3a.StatusHandler)
|
||||
apiRouter.Methods(http.MethodGet).Path("/healthz").HandlerFunc(s3a.StatusHandler)
|
||||
|
||||
apiRouter.Methods(http.MethodOptions).HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" {
|
||||
origin = "*"
|
||||
} else {
|
||||
originFound := false
|
||||
for _, allowedOrigin := range s3a.option.AllowedOrigins {
|
||||
if origin == allowedOrigin {
|
||||
originFound = true
|
||||
}
|
||||
}
|
||||
if !originFound {
|
||||
writeFailureResponse(w, r, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
writeSuccessResponseEmpty(w, r)
|
||||
})
|
||||
|
||||
var routers []*mux.Router
|
||||
if s3a.option.DomainName != "" {
|
||||
domainNames := strings.Split(s3a.option.DomainName, ",")
|
||||
@@ -168,7 +170,16 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
|
||||
}
|
||||
routers = append(routers, apiRouter.PathPrefix("/{bucket}").Subrouter())
|
||||
|
||||
// Get CORS middleware instance with caching
|
||||
corsMiddleware := s3a.getCORSMiddleware()
|
||||
|
||||
for _, bucket := range routers {
|
||||
// Apply CORS middleware to bucket routers for automatic CORS header handling
|
||||
bucket.Use(corsMiddleware.Handler)
|
||||
|
||||
// Bucket-specific OPTIONS handler for CORS preflight requests
|
||||
// Use PathPrefix to catch all bucket-level preflight routes including /bucket/object
|
||||
bucket.PathPrefix("/").Methods(http.MethodOptions).HandlerFunc(corsMiddleware.HandleOptionsRequest)
|
||||
|
||||
// each case should follow the next rule:
|
||||
// - requesting object with query must precede any other methods
|
||||
@@ -330,6 +341,25 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
|
||||
|
||||
}
|
||||
|
||||
// Global OPTIONS handler for service-level requests (non-bucket requests)
|
||||
// This handles requests like OPTIONS /, OPTIONS /status, OPTIONS /healthz
|
||||
// Place this after bucket handlers to avoid interfering with bucket CORS middleware
|
||||
apiRouter.Methods(http.MethodOptions).PathPrefix("/").HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only handle if this is not a bucket-specific request
|
||||
vars := mux.Vars(r)
|
||||
bucket := vars["bucket"]
|
||||
if bucket != "" {
|
||||
// This is a bucket-specific request, let bucket CORS middleware handle it
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if s3a.handleCORSOriginValidation(w, r) {
|
||||
writeSuccessResponseEmpty(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
// ListBuckets
|
||||
apiRouter.Methods(http.MethodGet).Path("/").HandlerFunc(track(s3a.ListBucketsHandler, "LIST"))
|
||||
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
)
|
||||
|
||||
type mimeType string
|
||||
@@ -76,10 +77,25 @@ func EncodeXMLResponse(response interface{}) []byte {
|
||||
func setCommonHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-amz-request-id", fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
|
||||
// Only set static CORS headers for service-level requests, not bucket-specific requests
|
||||
if r.Header.Get("Origin") != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
// Use mux.Vars to detect bucket-specific requests more reliably
|
||||
vars := mux.Vars(r)
|
||||
bucket := vars["bucket"]
|
||||
isBucketRequest := bucket != ""
|
||||
|
||||
// Only apply static CORS headers if this is NOT a bucket-specific request
|
||||
// and no bucket-specific CORS headers were already set
|
||||
if !isBucketRequest && w.Header().Get("Access-Control-Allow-Origin") == "" {
|
||||
// This is a service-level request (like OPTIONS /), apply static CORS
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
// For bucket-specific requests, let the CORS middleware handle the headers
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user