From 81d840699aadf901a2476a3637fa2ebe59302dec Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 13 Apr 2017 20:14:07 +0200 Subject: [PATCH] mutex protected close --- app/proxyman/mux/mux.go | 13 ++++++------- app/proxyman/mux/session.go | 30 +++++++++++++++++++++++++++--- app/proxyman/mux/session_test.go | 6 ++---- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index f2a4265df..5e8763119 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -121,13 +121,12 @@ func (m *Client) monitor() { for { select { case <-m.ctx.Done(): - m.sessionManager.Close() m.inboundRay.InboundInput().Close() m.inboundRay.InboundOutput().CloseError() return case <-time.After(time.Second * 6): size := m.sessionManager.Size() - if size == 0 { + if size == 0 && m.sessionManager.CloseIfNoSession() { m.cancel() } } @@ -169,12 +168,12 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool default: } - s := &Session{ - input: outboundRay.OutboundInput(), - output: outboundRay.OutboundOutput(), - parent: m.sessionManager, + s := m.sessionManager.Allocate() + if s == nil { + return false } - m.sessionManager.Allocate(s) + s.input = outboundRay.OutboundInput() + s.output = outboundRay.OutboundOutput() go fetchInput(ctx, s, m.inboundRay.InboundInput()) return true } diff --git a/app/proxyman/mux/session.go b/app/proxyman/mux/session.go index 33e9efee8..f66065aef 100644 --- a/app/proxyman/mux/session.go +++ b/app/proxyman/mux/session.go @@ -10,6 +10,7 @@ type SessionManager struct { sync.RWMutex count uint16 sessions map[uint16]*Session + closed bool } func NewSessionManager() *SessionManager { @@ -26,13 +27,21 @@ func (m *SessionManager) Size() int { return len(m.sessions) } -func (m *SessionManager) Allocate(s *Session) { +func (m *SessionManager) Allocate() *Session { m.Lock() defer m.Unlock() + if m.closed { + return nil + } + m.count++ - s.ID = m.count + s := &Session{ + ID: m.count, + parent: m, + } m.sessions[s.ID] = s + return s } func (m *SessionManager) Add(s *Session) { @@ -53,17 +62,32 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) { m.RLock() defer m.RUnlock() + if m.closed { + return nil, false + } + s, found := m.sessions[id] return s, found } -func (m *SessionManager) Close() { +func (m *SessionManager) CloseIfNoSession() bool { m.RLock() defer m.RUnlock() + if len(m.sessions) == 0 { + return false + } + + m.closed = true + for _, s := range m.sessions { + s.input.CloseError() s.output.CloseError() } + + m.sessions = make(map[uint16]*Session) + + return true } type Session struct { diff --git a/app/proxyman/mux/session_test.go b/app/proxyman/mux/session_test.go index e57de2f59..6ceb951c2 100644 --- a/app/proxyman/mux/session_test.go +++ b/app/proxyman/mux/session_test.go @@ -12,12 +12,10 @@ func TestSessionManagerAdd(t *testing.T) { m := NewSessionManager() - s := &Session{} - m.Allocate(s) + s := m.Allocate() assert.Uint16(s.ID).Equals(1) - s = &Session{} - m.Allocate(s) + s = m.Allocate() assert.Uint16(s.ID).Equals(2) s = &Session{