aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--internal/feed/mail.go12
-rw-r--r--internal/feed/parse.go47
-rw-r--r--internal/http/client.go88
3 files changed, 104 insertions, 43 deletions
diff --git a/internal/feed/mail.go b/internal/feed/mail.go
index 3cb442e..0360e10 100644
--- a/internal/feed/mail.go
+++ b/internal/feed/mail.go
@@ -7,7 +7,6 @@ import (
"io"
"io/ioutil"
"mime"
- "net/http"
"net/url"
"path"
"strings"
@@ -19,6 +18,7 @@ import (
"github.com/gabriel-vasile/mimetype"
"github.com/Necoro/feed2imap-go/internal/feed/template"
+ "github.com/Necoro/feed2imap-go/internal/http"
"github.com/Necoro/feed2imap-go/internal/msg"
"github.com/Necoro/feed2imap-go/pkg/config"
"github.com/Necoro/feed2imap-go/pkg/log"
@@ -193,12 +193,12 @@ func (feed *Feed) Messages() (msg.Messages, error) {
return mails, nil
}
-func getImage(src string, client *http.Client) ([]byte, string, error) {
- resp, err := client.Get(src)
+func getImage(src string, timeout int, disableTLS bool) ([]byte, string, error) {
+ resp, cancel, err := http.Get(src, timeout, disableTLS)
if err != nil {
return nil, "", fmt.Errorf("fetching from '%s': %w", src, err)
}
- defer resp.Body.Close()
+ defer cancel()
img, err := ioutil.ReadAll(resp.Body)
if err != nil {
@@ -270,7 +270,7 @@ func (item *item) buildBody() {
return
}
- srcUrl,err := url.Parse(src)
+ srcUrl, err := url.Parse(src)
if err != nil {
log.Errorf("Feed %s: Item %s: Error parsing URL '%s' embedded in item: %s",
feed.Name, item.Item.Link, src, err)
@@ -278,7 +278,7 @@ func (item *item) buildBody() {
}
imgUrl := feedUrl.ResolveReference(srcUrl)
- img, mime, err := getImage(imgUrl.String(), httpClient(feed.NoTLS))
+ img, mime, err := getImage(imgUrl.String(), feed.Global.Timeout, feed.NoTLS)
if err != nil {
log.Errorf("Feed %s: Item %s: Error fetching image: %s",
feed.Name, item.Item.Link, err)
diff --git a/internal/feed/parse.go b/internal/feed/parse.go
index 1ba90fd..a8f705a 100644
--- a/internal/feed/parse.go
+++ b/internal/feed/parse.go
@@ -1,57 +1,30 @@
package feed
import (
- ctxt "context"
- "crypto/tls"
"fmt"
- "net/http"
- "time"
"github.com/google/uuid"
"github.com/mmcdole/gofeed"
+ "github.com/Necoro/feed2imap-go/internal/http"
"github.com/Necoro/feed2imap-go/pkg/log"
)
-// share HTTP clients
-var (
- stdHTTPClient *http.Client
- unsafeHTTPClient *http.Client
-)
-
-func init() {
- // std
- stdHTTPClient = &http.Client{Transport: http.DefaultTransport}
-
- // unsafe
- tlsConfig := &tls.Config{InsecureSkipVerify: true}
- transport := http.DefaultTransport.(*http.Transport).Clone()
- transport.TLSClientConfig = tlsConfig
- unsafeHTTPClient = &http.Client{Transport: transport}
-}
-
-func context(timeout int) (ctxt.Context, ctxt.CancelFunc) {
- return ctxt.WithTimeout(ctxt.Background(), time.Duration(timeout)*time.Second)
-}
-
-func httpClient(disableTLS bool) *http.Client {
- if disableTLS {
- return unsafeHTTPClient
- }
- return stdHTTPClient
-}
-
func (feed *Feed) parse() error {
- ctx, cancel := context(feed.Global.Timeout)
- defer cancel()
-
fp := gofeed.NewParser()
- fp.Client = httpClient(feed.NoTLS)
- parsedFeed, err := fp.ParseURLWithContext(feed.Url, ctx)
+ // we do not use the http support in gofeed, so that we can control the behavior of http requests
+ // and ensure it to be the same in all places
+ resp, cancel, err := http.Get(feed.Url, feed.Global.Timeout, feed.NoTLS)
if err != nil {
return fmt.Errorf("while fetching %s from %s: %w", feed.Name, feed.Url, err)
}
+ defer cancel() // includes resp.Body.Close
+
+ parsedFeed, err := fp.Parse(resp.Body)
+ if err != nil {
+ return fmt.Errorf("parsing feed '%s': %w", feed.Name, err)
+ }
feed.feed = parsedFeed
feed.items = make([]item, len(parsedFeed.Items))
diff --git a/internal/http/client.go b/internal/http/client.go
new file mode 100644
index 0000000..c9af26e
--- /dev/null
+++ b/internal/http/client.go
@@ -0,0 +1,88 @@
+package http
+
+import (
+ ctxt "context"
+ "crypto/tls"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+// share HTTP clients
+var (
+ stdClient *http.Client
+ unsafeClient *http.Client
+)
+
+// Error represents an HTTP error returned by a server.
+type Error struct {
+ StatusCode int
+ Status string
+}
+
+func (err Error) Error() string {
+ return fmt.Sprintf("http error: %s", err.Status)
+}
+
+func init() {
+ // std
+ stdClient = &http.Client{Transport: http.DefaultTransport}
+
+ // unsafe
+ tlsConfig := &tls.Config{InsecureSkipVerify: true}
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.TLSClientConfig = tlsConfig
+ unsafeClient = &http.Client{Transport: transport}
+}
+
+func context(timeout int) (ctxt.Context, ctxt.CancelFunc) {
+ return ctxt.WithTimeout(ctxt.Background(), time.Duration(timeout)*time.Second)
+}
+
+func client(disableTLS bool) *http.Client {
+ if disableTLS {
+ return unsafeClient
+ }
+ return stdClient
+}
+
+var noop ctxt.CancelFunc = func() {}
+
+func Get(url string, timeout int, disableTLS bool) (resp *http.Response, cancel ctxt.CancelFunc, err error) {
+ prematureExit := true
+ ctx, ctxCancel := context(timeout)
+
+ cancel = func() {
+ if resp != nil {
+ _ = resp.Body.Close()
+ }
+ ctxCancel()
+ }
+
+ defer func() {
+ if prematureExit {
+ cancel()
+ }
+ }()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, noop, err
+ }
+ req.Header.Set("User-Agent", "Feed2Imap-Go/1.0")
+
+ resp, err = client(disableTLS).Do(req)
+ if err != nil {
+ return nil, noop, err
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, noop, Error{
+ StatusCode: resp.StatusCode,
+ Status: resp.Status,
+ }
+ }
+
+ prematureExit = false
+ return resp, cancel, nil
+}