mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2024-11-10 04:05:42 +01:00
[REFACTOR] PKT protocol
- Use `Fprintf` to convert to hex and do padding. Simplifies the code. - Use `Read()` and `io.ReadFull` instead of `ReadByte()`. Should improve performance and allows for cleaner code. - s/pktLineTypeUnknow/pktLineTypeUnknown. - Disallow empty Pkt line per the specification. - Disallow too large Pkt line per the specification. - Add unit tests.
This commit is contained in:
parent
a11116602e
commit
2c8bcc163e
3 changed files with 93 additions and 52 deletions
54
cmd/hook.go
54
cmd/hook.go
|
@ -583,7 +583,7 @@ Forgejo or set your environment appropriately.`, "")
|
|||
|
||||
for {
|
||||
// note: pktLineTypeUnknow means pktLineTypeFlush and pktLineTypeData all allowed
|
||||
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
|
||||
rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -604,7 +604,7 @@ Forgejo or set your environment appropriately.`, "")
|
|||
|
||||
if hasPushOptions {
|
||||
for {
|
||||
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
|
||||
rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -699,8 +699,8 @@ Forgejo or set your environment appropriately.`, "")
|
|||
type pktLineType int64
|
||||
|
||||
const (
|
||||
// UnKnow type
|
||||
pktLineTypeUnknow pktLineType = 0
|
||||
// Unknown type
|
||||
pktLineTypeUnknown pktLineType = 0
|
||||
// flush-pkt "0000"
|
||||
pktLineTypeFlush pktLineType = iota
|
||||
// data line
|
||||
|
@ -714,22 +714,16 @@ type gitPktLine struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
// Reads an Pkt-Line from `in`. If requestType is not unknown, it will a
|
||||
func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType) (*gitPktLine, error) {
|
||||
var (
|
||||
err error
|
||||
r *gitPktLine
|
||||
)
|
||||
|
||||
// read prefix
|
||||
// Read length prefix
|
||||
lengthBytes := make([]byte, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
lengthBytes[i], err = in.ReadByte()
|
||||
if err != nil {
|
||||
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
|
||||
}
|
||||
if n, err := in.Read(lengthBytes); n != 4 || err != nil {
|
||||
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
|
||||
}
|
||||
|
||||
r = new(gitPktLine)
|
||||
var err error
|
||||
r := &gitPktLine{}
|
||||
r.Length, err = strconv.ParseUint(string(lengthBytes), 16, 32)
|
||||
if err != nil {
|
||||
return nil, fail(ctx, "Protocol: format parse error", "Pkt-Line format is wrong :%v", err)
|
||||
|
@ -748,11 +742,8 @@ func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType)
|
|||
}
|
||||
|
||||
r.Data = make([]byte, r.Length-4)
|
||||
for i := range r.Data {
|
||||
r.Data[i], err = in.ReadByte()
|
||||
if err != nil {
|
||||
return nil, fail(ctx, "Protocol: data error", "Pkt-Line: read stdin failed : %v", err)
|
||||
}
|
||||
if n, err := io.ReadFull(in, r.Data); uint64(n) != r.Length-4 || err != nil {
|
||||
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
|
||||
}
|
||||
|
||||
r.Type = pktLineTypeData
|
||||
|
@ -768,20 +759,23 @@ func writeFlushPktLine(ctx context.Context, out io.Writer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Write an Pkt-Line based on `data` to `out` according to the specifcation.
|
||||
// https://git-scm.com/docs/protocol-common
|
||||
func writeDataPktLine(ctx context.Context, out io.Writer, data []byte) error {
|
||||
hexchar := []byte("0123456789abcdef")
|
||||
hex := func(n uint64) byte {
|
||||
return hexchar[(n)&15]
|
||||
// Implementations SHOULD NOT send an empty pkt-line ("0004").
|
||||
if len(data) == 0 {
|
||||
return fail(ctx, "Protocol: write error", "Not allowed to write empty Pkt-Line")
|
||||
}
|
||||
|
||||
length := uint64(len(data) + 4)
|
||||
tmp := make([]byte, 4)
|
||||
tmp[0] = hex(length >> 12)
|
||||
tmp[1] = hex(length >> 8)
|
||||
tmp[2] = hex(length >> 4)
|
||||
tmp[3] = hex(length)
|
||||
|
||||
lr, err := out.Write(tmp)
|
||||
// The maximum length of a pkt-line’s data component is 65516 bytes.
|
||||
// Implementations MUST NOT send pkt-line whose length exceeds 65520 (65516 bytes of payload + 4 bytes of length data).
|
||||
if length > 65520 {
|
||||
return fail(ctx, "Protocol: write error", "Pkt-Line exceeds maximum of 65520 bytes")
|
||||
}
|
||||
|
||||
lr, err := fmt.Fprintf(out, "%04x", length)
|
||||
if err != nil || lr != 4 {
|
||||
return fail(ctx, "Protocol: write error", "Pkt-Line response failed: %v", err)
|
||||
}
|
||||
|
|
|
@ -14,29 +14,72 @@ import (
|
|||
)
|
||||
|
||||
func TestPktLine(t *testing.T) {
|
||||
// test read
|
||||
ctx := context.Background()
|
||||
s := strings.NewReader("0000")
|
||||
r := bufio.NewReader(s)
|
||||
result, err := readPktLine(ctx, r, pktLineTypeFlush)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pktLineTypeFlush, result.Type)
|
||||
|
||||
s = strings.NewReader("0006a\n")
|
||||
r = bufio.NewReader(s)
|
||||
result, err = readPktLine(ctx, r, pktLineTypeData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pktLineTypeData, result.Type)
|
||||
assert.Equal(t, []byte("a\n"), result.Data)
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
s := strings.NewReader("0000")
|
||||
r := bufio.NewReader(s)
|
||||
result, err := readPktLine(ctx, r, pktLineTypeFlush)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pktLineTypeFlush, result.Type)
|
||||
|
||||
// test write
|
||||
w := bytes.NewBuffer([]byte{})
|
||||
err = writeFlushPktLine(ctx, w)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("0000"), w.Bytes())
|
||||
s = strings.NewReader("0006a\n")
|
||||
r = bufio.NewReader(s)
|
||||
result, err = readPktLine(ctx, r, pktLineTypeData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pktLineTypeData, result.Type)
|
||||
assert.Equal(t, []byte("a\n"), result.Data)
|
||||
|
||||
w.Reset()
|
||||
err = writeDataPktLine(ctx, w, []byte("a\nb"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("0007a\nb"), w.Bytes())
|
||||
s = strings.NewReader("0004")
|
||||
r = bufio.NewReader(s)
|
||||
result, err = readPktLine(ctx, r, pktLineTypeData)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
data := strings.Repeat("x", 65516)
|
||||
r = bufio.NewReader(strings.NewReader("fff0" + data))
|
||||
result, err = readPktLine(ctx, r, pktLineTypeData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pktLineTypeData, result.Type)
|
||||
assert.Equal(t, []byte(data), result.Data)
|
||||
|
||||
r = bufio.NewReader(strings.NewReader("fff1a"))
|
||||
result, err = readPktLine(ctx, r, pktLineTypeData)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
w := bytes.NewBuffer([]byte{})
|
||||
err := writeFlushPktLine(ctx, w)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("0000"), w.Bytes())
|
||||
|
||||
w.Reset()
|
||||
err = writeDataPktLine(ctx, w, []byte("a\nb"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("0007a\nb"), w.Bytes())
|
||||
|
||||
w.Reset()
|
||||
data := bytes.Repeat([]byte{0x05}, 288)
|
||||
err = writeDataPktLine(ctx, w, data)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, append([]byte("0124"), data...), w.Bytes())
|
||||
|
||||
w.Reset()
|
||||
err = writeDataPktLine(ctx, w, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, w.Bytes())
|
||||
|
||||
w.Reset()
|
||||
data = bytes.Repeat([]byte{0x64}, 65516)
|
||||
err = writeDataPktLine(ctx, w, data)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, append([]byte("fff0"), data...), w.Bytes())
|
||||
|
||||
w.Reset()
|
||||
err = writeDataPktLine(ctx, w, bytes.Repeat([]byte{0x64}, 65516+1))
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, w.Bytes())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
|
@ -106,7 +107,10 @@ func fail(ctx context.Context, userMessage, logMsgFmt string, args ...any) error
|
|||
logMsg = userMessage + ". " + logMsg
|
||||
}
|
||||
}
|
||||
_ = private.SSHLog(ctx, true, logMsg)
|
||||
// Don't send an log if this is done in a test and no InternalToken is set.
|
||||
if !testing.Testing() || setting.InternalToken != "" {
|
||||
_ = private.SSHLog(ctx, true, logMsg)
|
||||
}
|
||||
}
|
||||
return cli.Exit("", 1)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue