mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2025-09-19 12:17:59 +08:00

* adding cors support * address some comments * optimize matchesWildcard * address comments * fix for tests * address comments * address comments * address comments * path building * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment Service-level responses need both Access-Control-Allow-Methods and Access-Control-Allow-Headers. After setting Access-Control-Allow-Origin and Access-Control-Expose-Headers, also set Access-Control-Allow-Methods: * and Access-Control-Allow-Headers: * so service endpoints satisfy CORS preflight requirements. * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_server.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify * add cors tests * fix tests * fix tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
527 lines
12 KiB
Go
527 lines
12 KiB
Go
package cors
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
func TestValidateConfiguration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *CORSConfiguration
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
config: nil,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty rules",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid single rule",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"GET", "POST"},
|
|
AllowedOrigins: []string{"*"},
|
|
},
|
|
},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "too many rules",
|
|
config: &CORSConfiguration{
|
|
CORSRules: make([]CORSRule, 101),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid method",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"INVALID"},
|
|
AllowedOrigins: []string{"*"},
|
|
},
|
|
},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty origins",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"GET"},
|
|
AllowedOrigins: []string{},
|
|
},
|
|
},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid origin with multiple wildcards",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"GET"},
|
|
AllowedOrigins: []string{"http://*.*.example.com"},
|
|
},
|
|
},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "negative MaxAgeSeconds",
|
|
config: &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"GET"},
|
|
AllowedOrigins: []string{"*"},
|
|
MaxAgeSeconds: intPtr(-1),
|
|
},
|
|
},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := ValidateConfiguration(tt.config)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateOrigin(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
origin string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "empty origin",
|
|
origin: "",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid origin",
|
|
origin: "http://example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "wildcard origin",
|
|
origin: "*",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid wildcard origin",
|
|
origin: "http://*.example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "https wildcard origin",
|
|
origin: "https://*.example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid wildcard origin",
|
|
origin: "*.example.com",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "multiple wildcards",
|
|
origin: "http://*.*.example.com",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateOrigin(tt.origin)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *http.Request
|
|
want *CORSRequest
|
|
}{
|
|
{
|
|
name: "simple GET request",
|
|
req: &http.Request{
|
|
Method: "GET",
|
|
Header: http.Header{
|
|
"Origin": []string{"http://example.com"},
|
|
},
|
|
},
|
|
want: &CORSRequest{
|
|
Origin: "http://example.com",
|
|
Method: "GET",
|
|
IsPreflightRequest: false,
|
|
},
|
|
},
|
|
{
|
|
name: "OPTIONS preflight request",
|
|
req: &http.Request{
|
|
Method: "OPTIONS",
|
|
Header: http.Header{
|
|
"Origin": []string{"http://example.com"},
|
|
"Access-Control-Request-Method": []string{"PUT"},
|
|
"Access-Control-Request-Headers": []string{"Content-Type, Authorization"},
|
|
},
|
|
},
|
|
want: &CORSRequest{
|
|
Origin: "http://example.com",
|
|
Method: "OPTIONS",
|
|
IsPreflightRequest: true,
|
|
AccessControlRequestMethod: "PUT",
|
|
AccessControlRequestHeaders: []string{"Content-Type", "Authorization"},
|
|
},
|
|
},
|
|
{
|
|
name: "request without origin",
|
|
req: &http.Request{
|
|
Method: "GET",
|
|
Header: http.Header{},
|
|
},
|
|
want: &CORSRequest{
|
|
Origin: "",
|
|
Method: "GET",
|
|
IsPreflightRequest: false,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ParseRequest(tt.req)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("ParseRequest() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchesOrigin(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedOrigins []string
|
|
origin string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "wildcard match",
|
|
allowedOrigins: []string{"*"},
|
|
origin: "http://example.com",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "exact match",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://example.com",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "no match",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://other.com",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "wildcard subdomain match",
|
|
allowedOrigins: []string{"http://*.example.com"},
|
|
origin: "http://api.example.com",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "wildcard subdomain no match",
|
|
allowedOrigins: []string{"http://*.example.com"},
|
|
origin: "http://example.com",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "multiple origins with match",
|
|
allowedOrigins: []string{"http://example.com", "http://other.com"},
|
|
origin: "http://other.com",
|
|
want: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := matchesOrigin(tt.allowedOrigins, tt.origin)
|
|
if got != tt.want {
|
|
t.Errorf("matchesOrigin() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMatchesHeader(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedHeaders []string
|
|
header string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "empty allowed headers",
|
|
allowedHeaders: []string{},
|
|
header: "Content-Type",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "wildcard match",
|
|
allowedHeaders: []string{"*"},
|
|
header: "Content-Type",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "exact match",
|
|
allowedHeaders: []string{"Content-Type"},
|
|
header: "Content-Type",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "case insensitive match",
|
|
allowedHeaders: []string{"content-type"},
|
|
header: "Content-Type",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "no match",
|
|
allowedHeaders: []string{"Authorization"},
|
|
header: "Content-Type",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "wildcard prefix match",
|
|
allowedHeaders: []string{"x-amz-*"},
|
|
header: "x-amz-date",
|
|
want: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := matchesHeader(tt.allowedHeaders, tt.header)
|
|
if got != tt.want {
|
|
t.Errorf("matchesHeader() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEvaluateRequest(t *testing.T) {
|
|
config := &CORSConfiguration{
|
|
CORSRules: []CORSRule{
|
|
{
|
|
AllowedMethods: []string{"GET", "POST"},
|
|
AllowedOrigins: []string{"http://example.com"},
|
|
AllowedHeaders: []string{"Content-Type"},
|
|
ExposeHeaders: []string{"ETag"},
|
|
MaxAgeSeconds: intPtr(3600),
|
|
},
|
|
{
|
|
AllowedMethods: []string{"PUT"},
|
|
AllowedOrigins: []string{"*"},
|
|
},
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
config *CORSConfiguration
|
|
corsReq *CORSRequest
|
|
want *CORSResponse
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "matching first rule",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "http://example.com",
|
|
Method: "GET",
|
|
},
|
|
want: &CORSResponse{
|
|
AllowOrigin: "http://example.com",
|
|
AllowMethods: "GET, POST",
|
|
AllowHeaders: "Content-Type",
|
|
ExposeHeaders: "ETag",
|
|
MaxAge: "3600",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "matching second rule",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "http://other.com",
|
|
Method: "PUT",
|
|
},
|
|
want: &CORSResponse{
|
|
AllowOrigin: "http://other.com",
|
|
AllowMethods: "PUT",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "no matching rule",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "http://forbidden.com",
|
|
Method: "GET",
|
|
},
|
|
want: nil,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "preflight request",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "http://example.com",
|
|
Method: "OPTIONS",
|
|
IsPreflightRequest: true,
|
|
AccessControlRequestMethod: "POST",
|
|
AccessControlRequestHeaders: []string{"Content-Type"},
|
|
},
|
|
want: &CORSResponse{
|
|
AllowOrigin: "http://example.com",
|
|
AllowMethods: "GET, POST",
|
|
AllowHeaders: "Content-Type",
|
|
ExposeHeaders: "ETag",
|
|
MaxAge: "3600",
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "preflight request with forbidden header",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "http://example.com",
|
|
Method: "OPTIONS",
|
|
IsPreflightRequest: true,
|
|
AccessControlRequestMethod: "POST",
|
|
AccessControlRequestHeaders: []string{"Authorization"},
|
|
},
|
|
want: &CORSResponse{
|
|
AllowOrigin: "http://example.com",
|
|
// No AllowMethods or AllowHeaders because the requested header is forbidden
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "request without origin",
|
|
config: config,
|
|
corsReq: &CORSRequest{
|
|
Origin: "",
|
|
Method: "GET",
|
|
},
|
|
want: nil,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := EvaluateRequest(tt.config, tt.corsReq)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestApplyHeaders(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
corsResp *CORSResponse
|
|
want map[string]string
|
|
}{
|
|
{
|
|
name: "nil response",
|
|
corsResp: nil,
|
|
want: map[string]string{},
|
|
},
|
|
{
|
|
name: "complete response",
|
|
corsResp: &CORSResponse{
|
|
AllowOrigin: "http://example.com",
|
|
AllowMethods: "GET, POST",
|
|
AllowHeaders: "Content-Type",
|
|
ExposeHeaders: "ETag",
|
|
MaxAge: "3600",
|
|
},
|
|
want: map[string]string{
|
|
"Access-Control-Allow-Origin": "http://example.com",
|
|
"Access-Control-Allow-Methods": "GET, POST",
|
|
"Access-Control-Allow-Headers": "Content-Type",
|
|
"Access-Control-Expose-Headers": "ETag",
|
|
"Access-Control-Max-Age": "3600",
|
|
},
|
|
},
|
|
{
|
|
name: "with credentials",
|
|
corsResp: &CORSResponse{
|
|
AllowOrigin: "http://example.com",
|
|
AllowMethods: "GET",
|
|
AllowCredentials: true,
|
|
},
|
|
want: map[string]string{
|
|
"Access-Control-Allow-Origin": "http://example.com",
|
|
"Access-Control-Allow-Methods": "GET",
|
|
"Access-Control-Allow-Credentials": "true",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create a proper response writer using httptest
|
|
w := httptest.NewRecorder()
|
|
|
|
ApplyHeaders(w, tt.corsResp)
|
|
|
|
// Extract headers from the response
|
|
headers := make(map[string]string)
|
|
for key, values := range w.Header() {
|
|
if len(values) > 0 {
|
|
headers[key] = values[0]
|
|
}
|
|
}
|
|
|
|
if !reflect.DeepEqual(headers, tt.want) {
|
|
t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Helper functions and types for testing
|
|
|
|
func intPtr(i int) *int {
|
|
return &i
|
|
}
|