diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/feed/mail.go | 12 | ||||
-rw-r--r-- | internal/feed/parse.go | 47 | ||||
-rw-r--r-- | internal/http/client.go | 88 |
3 files changed, 104 insertions, 43 deletions
diff --git a/internal/feed/mail.go b/internal/feed/mail.go index 3cb442e..0360e10 100644 --- a/internal/feed/mail.go +++ b/internal/feed/mail.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" "mime" - "net/http" "net/url" "path" "strings" @@ -19,6 +18,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/Necoro/feed2imap-go/internal/feed/template" + "github.com/Necoro/feed2imap-go/internal/http" "github.com/Necoro/feed2imap-go/internal/msg" "github.com/Necoro/feed2imap-go/pkg/config" "github.com/Necoro/feed2imap-go/pkg/log" @@ -193,12 +193,12 @@ func (feed *Feed) Messages() (msg.Messages, error) { return mails, nil } -func getImage(src string, client *http.Client) ([]byte, string, error) { - resp, err := client.Get(src) +func getImage(src string, timeout int, disableTLS bool) ([]byte, string, error) { + resp, cancel, err := http.Get(src, timeout, disableTLS) if err != nil { return nil, "", fmt.Errorf("fetching from '%s': %w", src, err) } - defer resp.Body.Close() + defer cancel() img, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -270,7 +270,7 @@ func (item *item) buildBody() { return } - srcUrl,err := url.Parse(src) + srcUrl, err := url.Parse(src) if err != nil { log.Errorf("Feed %s: Item %s: Error parsing URL '%s' embedded in item: %s", feed.Name, item.Item.Link, src, err) @@ -278,7 +278,7 @@ func (item *item) buildBody() { } imgUrl := feedUrl.ResolveReference(srcUrl) - img, mime, err := getImage(imgUrl.String(), httpClient(feed.NoTLS)) + img, mime, err := getImage(imgUrl.String(), feed.Global.Timeout, feed.NoTLS) if err != nil { log.Errorf("Feed %s: Item %s: Error fetching image: %s", feed.Name, item.Item.Link, err) diff --git a/internal/feed/parse.go b/internal/feed/parse.go index 1ba90fd..a8f705a 100644 --- a/internal/feed/parse.go +++ b/internal/feed/parse.go @@ -1,57 +1,30 @@ package feed import ( - ctxt "context" - "crypto/tls" "fmt" - "net/http" - "time" "github.com/google/uuid" "github.com/mmcdole/gofeed" + "github.com/Necoro/feed2imap-go/internal/http" "github.com/Necoro/feed2imap-go/pkg/log" ) -// share HTTP clients -var ( - stdHTTPClient *http.Client - unsafeHTTPClient *http.Client -) - -func init() { - // std - stdHTTPClient = &http.Client{Transport: http.DefaultTransport} - - // unsafe - tlsConfig := &tls.Config{InsecureSkipVerify: true} - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = tlsConfig - unsafeHTTPClient = &http.Client{Transport: transport} -} - -func context(timeout int) (ctxt.Context, ctxt.CancelFunc) { - return ctxt.WithTimeout(ctxt.Background(), time.Duration(timeout)*time.Second) -} - -func httpClient(disableTLS bool) *http.Client { - if disableTLS { - return unsafeHTTPClient - } - return stdHTTPClient -} - func (feed *Feed) parse() error { - ctx, cancel := context(feed.Global.Timeout) - defer cancel() - fp := gofeed.NewParser() - fp.Client = httpClient(feed.NoTLS) - parsedFeed, err := fp.ParseURLWithContext(feed.Url, ctx) + // we do not use the http support in gofeed, so that we can control the behavior of http requests + // and ensure it to be the same in all places + resp, cancel, err := http.Get(feed.Url, feed.Global.Timeout, feed.NoTLS) if err != nil { return fmt.Errorf("while fetching %s from %s: %w", feed.Name, feed.Url, err) } + defer cancel() // includes resp.Body.Close + + parsedFeed, err := fp.Parse(resp.Body) + if err != nil { + return fmt.Errorf("parsing feed '%s': %w", feed.Name, err) + } feed.feed = parsedFeed feed.items = make([]item, len(parsedFeed.Items)) diff --git a/internal/http/client.go b/internal/http/client.go new file mode 100644 index 0000000..c9af26e --- /dev/null +++ b/internal/http/client.go @@ -0,0 +1,88 @@ +package http + +import ( + ctxt "context" + "crypto/tls" + "fmt" + "net/http" + "time" +) + +// share HTTP clients +var ( + stdClient *http.Client + unsafeClient *http.Client +) + +// Error represents an HTTP error returned by a server. +type Error struct { + StatusCode int + Status string +} + +func (err Error) Error() string { + return fmt.Sprintf("http error: %s", err.Status) +} + +func init() { + // std + stdClient = &http.Client{Transport: http.DefaultTransport} + + // unsafe + tlsConfig := &tls.Config{InsecureSkipVerify: true} + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConfig + unsafeClient = &http.Client{Transport: transport} +} + +func context(timeout int) (ctxt.Context, ctxt.CancelFunc) { + return ctxt.WithTimeout(ctxt.Background(), time.Duration(timeout)*time.Second) +} + +func client(disableTLS bool) *http.Client { + if disableTLS { + return unsafeClient + } + return stdClient +} + +var noop ctxt.CancelFunc = func() {} + +func Get(url string, timeout int, disableTLS bool) (resp *http.Response, cancel ctxt.CancelFunc, err error) { + prematureExit := true + ctx, ctxCancel := context(timeout) + + cancel = func() { + if resp != nil { + _ = resp.Body.Close() + } + ctxCancel() + } + + defer func() { + if prematureExit { + cancel() + } + }() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, noop, err + } + req.Header.Set("User-Agent", "Feed2Imap-Go/1.0") + + resp, err = client(disableTLS).Do(req) + if err != nil { + return nil, noop, err + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, noop, Error{ + StatusCode: resp.StatusCode, + Status: resp.Status, + } + } + + prematureExit = false + return resp, cancel, nil +} |