From 3a854c3bc47e75491b836c7fc12b617da5d68288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20=27Necoro=27=20Neumann?= Date: Tue, 16 Feb 2021 00:30:13 +0100 Subject: Issue #46: Fix semantics of `n` result Per contract, the returned number of bytes written should be the number of bytes _from the input_. Therefore, the added bytes (`\r` or `\n`) shall not count into that number. --- pkg/rfc822/writer.go | 18 ++++++++++-------- pkg/rfc822/writer_test.go | 6 +++++- 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'pkg') diff --git a/pkg/rfc822/writer.go b/pkg/rfc822/writer.go index 07751ea..dd96fbb 100644 --- a/pkg/rfc822/writer.go +++ b/pkg/rfc822/writer.go @@ -23,29 +23,31 @@ func (f rfc822Writer) Write(p []byte) (n int, err error) { crFound := false start := 0 - write := func(str []byte) { + write := func(str []byte, count bool) { var j int j, err = f.w.Write(str) - n = n + j + if count { + n += j + } } for idx, b := range p { if crFound && b != '\n' { // insert '\n' - if write(p[start:idx]); err != nil { + if write(p[start:idx], true); err != nil { return } - if write(lf); err != nil { + if write(lf, false); err != nil { return } start = idx } else if !crFound && b == '\n' { // insert '\r' - if write(p[start:idx]); err != nil { + if write(p[start:idx], true); err != nil { return } - if write(cr); err != nil { + if write(cr, false); err != nil { return } @@ -55,12 +57,12 @@ func (f rfc822Writer) Write(p []byte) (n int, err error) { } // write the remainder - if write(p[start:]); err != nil { + if write(p[start:], true); err != nil { return } if crFound { // dangling \r - write(lf) + write(lf, false) } return diff --git a/pkg/rfc822/writer_test.go b/pkg/rfc822/writer_test.go index 7beae8d..34c9d4a 100644 --- a/pkg/rfc822/writer_test.go +++ b/pkg/rfc822/writer_test.go @@ -34,10 +34,14 @@ func TestRfc822Writer_Write(t *testing.T) { t.Run(tt.before, func(t *testing.T) { b := bytes.Buffer{} w := Writer(&b) - if _, err := io.WriteString(w, tt.before); err != nil { + n, err := io.WriteString(w, tt.before) + if err != nil { t.Errorf("Error: %v", err) return } + if n != len(tt.before) { + t.Errorf("Unexpected number of bytes written: %d, expected: %d", n, len(tt.before)) + } res := b.String() if tt.after != res { t.Errorf("Expected: %q, got: %q", tt.after, res) -- cgit v1.2.3