155 lines
3.4 KiB
Go
155 lines
3.4 KiB
Go
|
package sftp
|
||
|
|
||
|
import (
|
||
|
"encoding"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
)
|
||
|
|
||
|
// _testSender is a fake packet sender for the packet-manager tests. Every
// packet passed to sendPacket is forwarded on the sent channel so a test can
// observe which packets were sent and in what order.
type _testSender struct {
	sent chan encoding.BinaryMarshaler
}

// newTestSender returns a _testSender with an unbuffered sent channel, so
// each sendPacket call blocks until the test receives that packet.
func newTestSender() *_testSender {
	ch := make(chan encoding.BinaryMarshaler)
	return &_testSender{sent: ch}
}

// sendPacket forwards p on s.sent and always returns nil. With the
// unbuffered channel created by newTestSender, it blocks until a receiver
// takes the packet.
func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
	out := s.sent
	out <- p
	return nil
}
|
||
|
|
||
|
// fakepacket is a minimal stand-in packet for the packet-manager tests: its
// numeric value doubles as its ID and it has no payload.
type fakepacket uint32

// MarshalBinary returns an empty (non-nil) byte slice and no error.
func (f fakepacket) MarshalBinary() ([]byte, error) {
	payload := make([]byte, 0)
	return payload, nil
}

// UnmarshalBinary ignores its input and always succeeds.
func (f fakepacket) UnmarshalBinary(_ []byte) error {
	return nil
}

// id reports the packet's ID, which is simply its numeric value.
func (f fakepacket) id() uint32 {
	n := uint32(f)
	return n
}
|
||
|
|
||
|
type pair struct {
|
||
|
in fakepacket
|
||
|
out fakepacket
|
||
|
}
|
||
|
|
||
|
// basic test
|
||
|
var ttable1 = []pair{
|
||
|
pair{fakepacket(0), fakepacket(0)},
|
||
|
pair{fakepacket(1), fakepacket(1)},
|
||
|
pair{fakepacket(2), fakepacket(2)},
|
||
|
pair{fakepacket(3), fakepacket(3)},
|
||
|
}
|
||
|
|
||
|
// outgoing packets out of order
|
||
|
var ttable2 = []pair{
|
||
|
pair{fakepacket(0), fakepacket(0)},
|
||
|
pair{fakepacket(1), fakepacket(4)},
|
||
|
pair{fakepacket(2), fakepacket(1)},
|
||
|
pair{fakepacket(3), fakepacket(3)},
|
||
|
pair{fakepacket(4), fakepacket(2)},
|
||
|
}
|
||
|
|
||
|
// incoming packets out of order
|
||
|
var ttable3 = []pair{
|
||
|
pair{fakepacket(2), fakepacket(0)},
|
||
|
pair{fakepacket(1), fakepacket(1)},
|
||
|
pair{fakepacket(3), fakepacket(2)},
|
||
|
pair{fakepacket(0), fakepacket(3)},
|
||
|
}
|
||
|
|
||
|
var tables = [][]pair{ttable1, ttable2, ttable3}
|
||
|
|
||
|
func TestPacketManager(t *testing.T) {
|
||
|
sender := newTestSender()
|
||
|
s := newPktMgr(sender)
|
||
|
|
||
|
for i := range tables {
|
||
|
table := tables[i]
|
||
|
for _, p := range table {
|
||
|
s.incomingPacket(p.in)
|
||
|
}
|
||
|
for _, p := range table {
|
||
|
s.readyPacket(p.out)
|
||
|
}
|
||
|
for i := 0; i < len(table); i++ {
|
||
|
pkt := <-sender.sent
|
||
|
id := pkt.(fakepacket).id()
|
||
|
assert.Equal(t, id, uint32(i))
|
||
|
}
|
||
|
}
|
||
|
s.close()
|
||
|
}
|
||
|
|
||
|
func (p sshFxpRemovePacket) String() string {
|
||
|
return fmt.Sprintf("RmPct:%d", p.ID)
|
||
|
}
|
||
|
func (p sshFxpOpenPacket) String() string {
|
||
|
return fmt.Sprintf("OpPct:%d", p.ID)
|
||
|
}
|
||
|
func (p sshFxpWritePacket) String() string {
|
||
|
return fmt.Sprintf("WrPct:%d", p.ID)
|
||
|
}
|
||
|
func (p sshFxpClosePacket) String() string {
|
||
|
return fmt.Sprintf("ClPct:%d", p.ID)
|
||
|
}
|
||
|
|
||
|
// Test what happens when the pool processes a close packet on a file that it
|
||
|
// is still reading from.
|
||
|
func TestCloseOutOfOrder(t *testing.T) {
|
||
|
packets := []requestPacket{
|
||
|
&sshFxpRemovePacket{ID: 0, Filename: "foo"},
|
||
|
&sshFxpOpenPacket{ID: 1},
|
||
|
&sshFxpWritePacket{ID: 2, Handle: "foo"},
|
||
|
&sshFxpWritePacket{ID: 3, Handle: "foo"},
|
||
|
&sshFxpWritePacket{ID: 4, Handle: "foo"},
|
||
|
&sshFxpWritePacket{ID: 5, Handle: "foo"},
|
||
|
&sshFxpClosePacket{ID: 6, Handle: "foo"},
|
||
|
&sshFxpRemovePacket{ID: 7, Filename: "foo"},
|
||
|
}
|
||
|
|
||
|
recvChan := make(chan requestPacket, len(packets)+1)
|
||
|
sender := newTestSender()
|
||
|
pktMgr := newPktMgr(sender)
|
||
|
wg := sync.WaitGroup{}
|
||
|
wg.Add(len(packets))
|
||
|
runWorker := func(ch requestChan) {
|
||
|
go func() {
|
||
|
for pkt := range ch {
|
||
|
if _, ok := pkt.(*sshFxpWritePacket); ok {
|
||
|
// sleep to cause writes to come after close/remove
|
||
|
time.Sleep(time.Millisecond)
|
||
|
}
|
||
|
pktMgr.working.Done()
|
||
|
recvChan <- pkt
|
||
|
wg.Done()
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
pktChan := pktMgr.workerChan(runWorker)
|
||
|
for _, p := range packets {
|
||
|
pktChan <- p
|
||
|
}
|
||
|
wg.Wait()
|
||
|
close(recvChan)
|
||
|
received := []requestPacket{}
|
||
|
for p := range recvChan {
|
||
|
received = append(received, p)
|
||
|
}
|
||
|
if received[len(received)-2].id() != packets[len(packets)-2].id() {
|
||
|
t.Fatal("Packets processed out of order1:", received, packets)
|
||
|
}
|
||
|
if received[len(received)-1].id() != packets[len(packets)-1].id() {
|
||
|
t.Fatal("Packets processed out of order2:", received, packets)
|
||
|
}
|
||
|
}
|