1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
16
17 package s3
18
19 import (
20 "bytes"
21 "context"
22 "encoding/xml"
23 "fmt"
24 "io"
25 "log"
26 "net/http"
27 "net/url"
28 "regexp"
29 "strings"
30
31 "github.com/aws/aws-sdk-go/aws"
32 "github.com/aws/aws-sdk-go/aws/awserr"
33 "github.com/aws/aws-sdk-go/aws/client"
34 "github.com/aws/aws-sdk-go/aws/endpoints"
35 "github.com/aws/aws-sdk-go/aws/request"
36 "github.com/aws/aws-sdk-go/service/s3"
37 "github.com/aws/aws-sdk-go/service/s3/s3iface"
38 "github.com/aws/aws-sdk-go/service/s3/s3manager"
39 )
40
// bucketInfo is the result of normalizing a bucket's location: the endpoint
// to talk to, the region to use, and whether that endpoint belongs to AWS.
type bucketInfo struct {
	// endpoint is the host used to reach the bucket. For official AWS
	// endpoints this is the host of the matching S3 endpoint known to the SDK.
	endpoint string

	// region is the region resolved for the bucket. It may be empty for
	// non-AWS endpoints when no region was configured or discovered.
	region string

	// isAWS reports whether endpoint matched an official AWS S3 endpoint.
	isAWS bool
}
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63 func normalizeBucketLocation(ctx context.Context, cfg client.ConfigProvider, endpoint string, bucket string, configRegion string) (bucketInfo, error) {
64 if strings.HasPrefix(endpoint, "https://") || strings.HasPrefix(endpoint, "http://") {
65 return bucketInfo{}, fmt.Errorf("invalid s3 endpoint: must not include uri scheme")
66 }
67
68 svc := s3.New(cfg)
69 endpoint, region, err := determineEndpoint(ctx, svc, endpoint, bucket, configRegion)
70 if err != nil {
71 return bucketInfo{}, err
72 }
73 if region != "" {
74 svc.Config.WithRegion(region)
75 }
76 isAWS, endpoint, err := endpointIsOfficial(endpoint)
77 if err != nil {
78 return bucketInfo{}, err
79 }
80
81 if !isAWS {
82 return bucketInfo{
83 endpoint: endpoint,
84 isAWS: isAWS,
85 region: region,
86 }, nil
87 }
88
89
90 svc.Config.WithEndpoint(endpoint)
91 region, err = s3manager.GetBucketRegion(ctx, cfg, bucket, region)
92
93
94 if isAWS && err != nil {
95 return bucketInfo{}, err
96 }
97 return bucketInfo{
98 endpoint: endpoint,
99 isAWS: isAWS,
100 region: region,
101 }, nil
102 }
103
104
105
106
107
108
109
110 func determineEndpoint(ctx context.Context, svc s3iface.S3API, endpoint, bucket, region string) (string, string, error) {
111 req, _ := svc.ListObjectsV2Request(&s3.ListObjectsV2Input{
112 Bucket: &bucket,
113 MaxKeys: aws.Int64(1),
114 })
115 if region != "" {
116 req.ClientInfo.SigningRegion = region
117 }
118 req.Config.S3ForcePathStyle = aws.Bool(true)
119 req.DisableFollowRedirects = true
120 req.SetContext(ctx)
121
122 var determinedEndpoint string
123 req.Handlers.UnmarshalError.PushFront(func(r *request.Request) {
124 if r.HTTPResponse.StatusCode != http.StatusMovedPermanently {
125 return
126 }
127 var b bytes.Buffer
128 if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
129 r.Error = fmt.Errorf("error reading body: %v", err)
130 return
131 }
132
133 type endpointErr struct {
134 Endpoint string `xml:"Endpoint"`
135 }
136
137 var epErr endpointErr
138 err := xml.NewDecoder(&b).Decode(&epErr)
139 if err != nil {
140 r.Error = err
141 return
142 }
143 determinedEndpoint = epErr.Endpoint
144 r.HTTPResponse.Body = io.NopCloser(&b)
145 })
146 err := req.Send()
147 if determinedEndpoint == "" && err != nil {
148 if region == "" {
149
150 if newRegion := regionFromMalformedAuthHeaderError(err); newRegion != "" {
151
152 return determineEndpoint(ctx, svc, endpoint, bucket, newRegion)
153 }
154 }
155 return "", "", fmt.Errorf("s3: could not determine endpoint: %v", err)
156 }
157
158
159
160 if determinedEndpoint == "" {
161 return endpoint, region, nil
162 }
163
164 determinedEndpoint = strings.TrimPrefix(determinedEndpoint, bucket+".")
165 return determinedEndpoint, region, nil
166 }
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182 func endpointIsOfficial(endpoint string) (bool, string, error) {
183 for _, partition := range endpoints.DefaultPartitions() {
184 for _, region := range partition.Regions() {
185 s3Endpoint, err := region.ResolveEndpoint(endpoints.S3ServiceID)
186 if err != nil {
187
188 continue
189 }
190 p, err := url.Parse(s3Endpoint.URL)
191 if err != nil {
192 return false, endpoint, err
193 }
194
195 if strings.HasSuffix(endpoint, p.Host) {
196 return true, p.Host, nil
197 }
198 }
199 }
200 return false, endpoint, nil
201 }
202
// malformedAuthHeaderMessageRegexp matches the portion of an error message
// that says which region was wrong and which one was expected. The expected
// region is captured as the first (and only) submatch.
var malformedAuthHeaderMessageRegexp = regexp.MustCompile(`region '[^']+' is wrong; expecting '([^']+)'`)
204
205
206
207
208
209
210
211 func regionFromMalformedAuthHeaderError(err error) string {
212 if aerr, ok := err.(awserr.Error); ok && aerr.Code() == "AuthorizationHeaderMalformed" {
213 matches := malformedAuthHeaderMessageRegexp.FindStringSubmatch(aerr.Message())
214 if len(matches) == 2 {
215 return matches[1]
216 }
217 log.Printf("s3: got AuthorizationHeaderMalformed, but couldn't parse message: %v", aerr.Message())
218 }
219 return ""
220 }