当前位置: 首页 > 工具软件 > Quod Libet > 使用案例 >

基于select模型的windows网络库libet(三)EventBase

易刚捷
2023-12-01

EventBase.h

#pragma once
#include <memory>
#include <WinSock2.h>

#include "Global.h"
#include "Threads.h"
#include "Buffer.h"

//Forward declaration: AllocBase() returns the concrete loop type defined below.
struct EventBase;

//Interface for anything that can hand out an event loop for a connection to
//run on (a single EventBase, or a MultiBase pool).
struct EventBases : private noncopyable 
{
	//Polymorphic base: a virtual destructor makes deleting a derived object
	//through an EventBases* well-defined.
	virtual ~EventBases() = default;

	//Return the event loop a new connection should be attached to.
	virtual EventBase* AllocBase() = 0;
};

//Event dispatcher: owns the poller, the timer maps, the idle-connection
//tracking and the cross-thread task queue for one I/O loop.
struct EventBase : public EventBases
{
	//taskCapacity sizes the SafeQueue of cross-thread tasks
	//(presumably 0 means unbounded — TODO confirm in Threads.h).
	EventBase(int taskCapacity = 0);
	~EventBase();

	//Enter the event-processing loop; returns only after Exit() is called.
	void Loop();
	//Run one poll iteration (waiting at most waitMs ms) and fire expired timers.
	void LoopOnce(int waitMs);
	//Cancel a scheduled task; ignored (returns false) if the timer already expired.
	bool Cancel(TimerId timerid);
	//Schedule a task at absolute time `milli` (ms); it runs inside HandleTimeouts.
	//A non-zero `interval` makes it repeat every `interval` ms.
	TimerId RunAt(int64_t milli, const Task &task, int64_t interval = 0) { return RunAt(milli, Task(task), interval); }
	TimerId RunAt(int64_t milli, Task &&task, int64_t interval = 0);
	//Schedule relative to the current time.
	TimerId RunAfter(int64_t milli, const Task &task, int64_t interval = 0) { return RunAt(util::timeMilli() + milli, Task(task), interval); }
	TimerId RunAfter(int64_t milli, Task &&task, int64_t interval = 0) { return RunAt(util::timeMilli() + milli, std::move(task), interval); }
	//Schedule starting now, repeating every `interval` ms.
	TimerId RunEvery(const Task &task, int64_t interval = 0) { return RunAt(util::timeMilli(), Task(task), interval); }
	TimerId RunEvery(Task &&task, int64_t interval = 0) { return RunAt(util::timeMilli(), std::move(task), interval); }

	//Request loop exit; Loop() observes the flag on its next iteration.
	EventBase& Exit();
	bool Exited();
	//Wake the I/O thread by writing one byte into the wakeup socket pair.
	void Wakeup();
	//Queue a task to run on the loop thread, then wake it.
	void SafeCall(Task &&task);
	void SafeCall(const Task &task) { SafeCall(Task(task)); }
	virtual EventBase* AllocBase() override { return this; }	

	void AddConn(TcpConnPtr&& conn);
	void AddConn(const TcpConnPtr& conn);
	void RemoveConn(const TcpConnPtr& conn);

	//Queue `buf` on the connection thread pool for sending to one / all conns.
	void PostDataToOneConn(Buffer& buf, const std::weak_ptr<TcpConn>& wpConn);
	void PostDataToAllConns(Buffer& buf);


public:
	//Emulate pipe() with a loopback TCP socket pair (select on Windows only
	//works on sockets). Returns 0 on success, SOCKET_ERROR on failure.
	int CreatePipe(SOCKET fds[2]);
	void Init();
	//Run the callbacks of connections whose idle time elapsed; driven by a
	//repeating timer, i.e. executed from HandleTimeouts.
	void CallIdles();

	IdleId RegisterIdle(int idle, const TcpConnPtr &con, const TcpCallBack &cb);
	void UnregisterIdle(const IdleId &id);
	void UpdateIdle(const IdleId &id);
	void HandleTimeouts();
	void RefreshNearest(const TimerId *tid = nullptr);
	void RepeatableTimeout(TimerRepeatable *tr);

	//void SendThreadProc();

public:
	PollerBase* m_poller;
	std::atomic<bool> m_exit;
	//[0] is the read end watched by the wakeup Channel, [1] is written by Wakeup().
	SOCKET m_wakeupFds[2];
	//Milliseconds until the nearest timer fires; 1<<30 acts as "none pending".
	int m_nextTimeout;
	SafeQueue<Task> m_tasks;
	std::unique_ptr<ConnThreadPool> m_upConnThreadPool;

	std::map<TimerId, TimerRepeatable> m_timerReps;
	std::map<TimerId, Task> m_timers;
	std::atomic<int64_t> m_timerSeq;	//timer sequence number (tie-breaker inside TimerId)
	//Connections grouped by idle interval (seconds). Newest entries sit at the
	//back of each list; activity moves a connection back to the tail
	//(memcache-style LRU ordering).
	std::map<int, std::list<IdleNode>> m_idleConns;
	bool m_idleEnabled;

	SafeList<TcpConnPtr> m_conns;
};

//Multi-threaded event dispatcher: owns `sz` EventBase loops and deals new
//connections out to them round-robin via AllocBase().
struct MultiBase : public EventBases 
{
	MultiBase(int sz) : m_id(0), m_bases(sz) {}
	//Round-robin pick of the next loop.
	virtual EventBase *AllocBase() override;
	//Run every loop; the calling thread runs the last one itself.
	void Loop();
	MultiBase &Exit();

private:
	std::atomic<int> m_id;	//round-robin counter, only ever grows
	std::vector<EventBase> m_bases;
};

//Wraps one socket registered with a poller and dispatches its read/write
//events to user-installed callbacks.
struct Channel : private noncopyable
{
	//Registers itself with base's poller on construction.
	//NOTE(review): fd is declared int while SOCKET is 64-bit on Win64 —
	//possible handle truncation; confirm and consider taking SOCKET instead.
	Channel(EventBase* base, int fd);
	~Channel();

	EventBase* GetBase() { return m_base; }
	SOCKET fd() { return m_fd; }
	int64_t id() { return m_id; }

	/* 
	 * Passive close:   HandleRead->Cleanup->~Channel->Channel::Close->HandleRead->~TcpConn
	 * Active close:    TcpConn::Close->Channel::Close->HandleRead->Cleanup->~Channel->Channel::Close->~TcpConn
	 * TryDecode fails: Channel::Close->HandleRead->Cleanup->~Channel->Channel::Close->~TcpConn
	*/
	void Close();

	//Install the event handlers.
	void OnRead(const Task &readcb) { m_readCb = readcb; }
	void OnWrite(const Task &writecb) { m_writeCb = writecb; }
	void OnRead(Task &&readcb) { m_readCb = std::move(readcb); }
	void OnWrite(Task &&writecb) { m_writeCb = std::move(writecb); }

	//Dispatch read/write events to the installed handlers.
	void HandleRead() { m_readCb(); }
	void HandleWrite() { m_writeCb(); }

	//Toggle interest in read/write events and re-register with the poller.
	void EnableRead(bool enable);
	void EnableWrite(bool enable);
	bool Readable() { return m_readable; }
	bool Writable() { return m_writable; }

protected:
	SOCKET m_fd;							
	EventBase* m_base;
	int64_t m_id;	//unique channel id (process-wide counter)
	bool m_readable, m_writable;
	std::function<void()> m_readCb, m_writeCb, m_errorCb;
};

EventBase.cpp

#include <iostream>

#include "EventBase.h"
#include "PollerBase.h"
#include "TcpConn.h"
#include "Timer.h"
#include "Logger.h"

using namespace std;

EventBase::EventBase(int taskCapacity)
	: m_poller(CreatePoller()), m_exit(false), m_tasks(taskCapacity),
	m_nextTimeout(1 << 30), m_timerSeq(0), m_idleEnabled(false)
{
	//Size the connection worker pool from the number of logical processors.
	SYSTEM_INFO si;
	GetSystemInfo(&si);
	m_upConnThreadPool = make_unique<ConnThreadPool>(si.dwNumberOfProcessors);

	//The wakeup socket pair is created inside Init(); mark both ends invalid
	//until then.
	for (auto &fd : m_wakeupFds)
		fd = SOCKET_ERROR;
	Init();
}

EventBase::~EventBase()
{
	//Stop the worker pool before tearing anything else down.
	m_upConnThreadPool->Exit();
	m_upConnThreadPool->Join();

	delete m_poller;

	//Closing the write end makes recv() on the read end return 0, letting the
	//wakeup Channel delete itself (see the OnRead handler in Init()).
	//NOTE(review): the poller is already deleted and the loop has stopped at
	//this point, so that recv may never run — confirm the wakeup Channel and
	//m_wakeupFds[0] are not leaked on this path.
	::closesocket(m_wakeupFds[1]);
	//Balances the WSAStartup() performed in Init().
	WSACleanup();
}

//Take ownership of the connection and track it in the shared connection list.
void EventBase::AddConn(TcpConnPtr&& conn)
{
	m_conns.EmplaceBack(std::move(conn));
}

//Track a shared connection in the connection list (copies the shared_ptr).
void EventBase::AddConn(const TcpConnPtr & conn)
{
	m_conns.EmplaceBack(conn);
}

//Stop tracking the connection; drops this list's reference to it.
void EventBase::RemoveConn(const TcpConnPtr& conn)
{
	m_conns.Remove(conn);
}

void EventBase::PostDataToOneConn(Buffer & buf, const std::weak_ptr<TcpConn>& wpConn)
{
	//Promote the weak reference; silently drop the data if the connection is gone.
	auto sp = wpConn.lock();
	if (!sp)
		return;

	//The task copies `buf` so it outlives the caller; tasks keyed by the same
	//channel id are serviced by the same worker thread.
	m_upConnThreadPool->AddTask(sp->GetChannel()->id(), [sp, buf]() mutable
	{
		sp->SendMsg(buf);
	});
}

//Broadcast: queue one send task per live connection.
//The lambda is handed to ForEach directly — the previous std::function
//wrapper added type erasure (and a possible allocation) for nothing.
//Each task copies `buf` so the data stays valid on the worker thread.
void EventBase::PostDataToAllConns(Buffer& buf)
{
	m_conns.ForEach([this, &buf](const TcpConnPtr& con)
	{
		m_upConnThreadPool->AddTask(con->GetChannel()->id(), [con, buf]() mutable
		{
			con->SendMsg(buf);
		});
	});
}

//Emulate pipe() with a loopback TCP pair: listen on 127.0.0.1:<ephemeral>,
//connect to it, accept, then close the listener. Windows select() only
//operates on sockets, so a real pipe cannot be used to wake the loop.
//Returns 0 on success (fds filled in), SOCKET_ERROR on failure.
//Note: socket()/accept() report failure with INVALID_SOCKET, not
//SOCKET_ERROR — the old comparison only worked by accidental integer
//conversion (both convert to the same SOCKET value).
int EventBase::CreatePipe(SOCKET fds[2])
{
	SOCKET tcp1 = INVALID_SOCKET, tcp2 = INVALID_SOCKET;
	sockaddr_in name;
	memset(&name, 0, sizeof(name));
	name.sin_family = AF_INET;
	name.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	name.sin_port = 0;	//let the stack pick a free port
	int namelen = sizeof(name);
	
	SOCKET tcp = ::socket(AF_INET, SOCK_STREAM, 0);
	if (tcp == INVALID_SOCKET) 
		goto clean;
	
	if (::bind(tcp, (sockaddr*)&name, namelen) == SOCKET_ERROR) 
		goto clean;
	
	if (::listen(tcp, 5) == SOCKET_ERROR) 
		goto clean;
	
	//Recover the ephemeral port bind() chose so the client can connect to it.
	if (getsockname(tcp, (sockaddr*)&name, &namelen) == SOCKET_ERROR) 
		goto clean;
	
	tcp1 = ::socket(AF_INET, SOCK_STREAM, 0);
	if (tcp1 == INVALID_SOCKET) 
		goto clean;
	
	if (SOCKET_ERROR == ::connect(tcp1, (sockaddr*)&name, namelen)) 
		goto clean;
	
	tcp2 = ::accept(tcp, nullptr, nullptr);
	if (tcp2 == INVALID_SOCKET) 
		goto clean;
	
	//The listener is no longer needed. Invalidate `tcp` before closing so a
	//close failure cannot cause a second closesocket() in the cleanup path.
	{
		SOCKET listener = tcp;
		tcp = INVALID_SOCKET;
		if (::closesocket(listener) == SOCKET_ERROR) 
			goto clean;
	}
	
	fds[0] = tcp1;	//read end (watched by the wakeup Channel)
	fds[1] = tcp2;	//write end (used by Wakeup())
	return 0;
clean:
	LCritical("Create Pipe failed,code:{}", errcode);
	if (tcp != INVALID_SOCKET)
		::closesocket(tcp);
	
	if (tcp2 != INVALID_SOCKET) 
		::closesocket(tcp2);
	
	if (tcp1 != INVALID_SOCKET) 
		::closesocket(tcp1);
	
	return SOCKET_ERROR;
}

void EventBase::Init()
{
	WSADATA wsa_data;
	if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0)
		return;

	if (CreatePipe(m_wakeupFds) == SOCKET_ERROR)
	{
		LCritical("CreatePipe error");
		return;
	}

	//创建唤醒通道,该通道用于唤醒I/O线程
	Channel *ch = new Channel(this, m_wakeupFds[0]);
	LInfo("wakeup fd:{}", m_wakeupFds[0]);

	ch->OnRead([ch, this]
	{
		char buf[1024];
		int r = (ch->fd() != SOCKET_ERROR) ? ::recv(ch->fd(), buf, sizeof(buf), 0) : 0;
		if (r > 0)
		{
			Task task;
			LDebug("EventBase addr: {}, m_tasks size: {}", (int32_t)this, m_tasks.size());
			while (m_tasks.pop_wait(&task, 0)) 	//从任务队列取一个任务执行
			{
				task();
			}
		}
		else if (r == 0) 	//closesocket(m_wakeupFds)后
		{
			LDebug("close wakeup fd: {}, thread id:{}", ch->fd(), std::hash<std::thread::id>{}(std::this_thread::get_id()));
			delete ch;
		}
		//else if (errno == EINTR) {
		//}
		else
			LCritical("wakeup channel recv error, ret:{},code{}", r, errcode);
	});
}

//Fire the callbacks of connections whose idle interval elapsed.
//Fix: the previous version iterated over *copies* of each bucket
//(`auto l = *it; auto lst = l.second;`), so the m_updated refresh and the
//LRU splice were applied to throwaway lists and every idle callback re-fired
//on each tick. We now mutate the real lists by reference.
void EventBase::CallIdles()
{
	int64_t now = util::timeMilli() / 1000;
	auto it = m_idleConns.begin();
	while (it != m_idleConns.end())
	{
		int idle = it->first;
		auto &lst = it->second;
		++it;	//advance before running callbacks, which may touch this bucket

		//Visit at most the current number of nodes so a callback that
		//re-registers an idle entry cannot keep this loop spinning.
		size_t cnt = lst.size();
		while (cnt-- && !lst.empty())
		{
			IdleNode &node = lst.front();
			if (node.m_updated + idle > now)
				break;	//list is LRU-ordered: everything behind is newer

			node.m_updated = now;
			lst.splice(lst.end(), lst, lst.begin());
			//Copy cb/conn out of the node first: the callback may call
			//UnregisterIdle and erase the very node we are standing in.
			auto cb = node.m_cb;
			auto conn = node.m_conn;
			cb(conn);
		}
	}
}

//Register `con` in the idle bucket for `idle` seconds; `cb` runs whenever the
//connection has been inactive that long (driven by CallIdles).
//Returns a handle that UpdateIdle/UnregisterIdle use to find the node again.
IdleId EventBase::RegisterIdle(int idle, const TcpConnPtr & con, const TcpCallBack & cb)
{
	//Lazily start the once-per-second repeating timer that drives CallIdles().
	if (!m_idleEnabled) 
	{
		RunAfter(1000, [this] { CallIdles(); }, 1000);
		m_idleEnabled = true;
	}
	auto &lst = m_idleConns[idle];
	//cb is a const reference, so it is copied here; the previous
	//std::move(cb) was a no-op on a const lvalue and only obscured that.
	lst.push_back(IdleNode{ con, util::timeMilli() / 1000, cb });
	LDebug("RegisterIdle() m_lst size: {}", lst.size());
	//The handle stores the owning list plus the new node's iterator; list
	//iterators stay valid across other insertions/removals.
	return IdleId(new IdleIdImp(&lst, --lst.end()));
}

//Remove the idle entry referenced by `id`.
void EventBase::UnregisterIdle(const IdleId & id)
{
	//m_lst points into m_idleConns; if m_idleConns were ever cleared, every
	//IdleId's list pointer and iterator would dangle (which is why Loop()
	//keeps its clear() commented out). Erasing here invalidates only this
	//node's iterator.
	id->m_lst->erase(id->m_iter);
}

//Mark the connection as active: refresh its timestamp (seconds) and move it
//to the back of its bucket so CallIdles sees it as most-recently used.
void EventBase::UpdateIdle(const IdleId & id)
{
	id->m_iter->m_updated = util::timeMilli() / 1000;
	id->m_lst->splice(id->m_lst->end(), *id->m_lst, id->m_iter);	//splice the node to the tail of the list
}

void EventBase::HandleTimeouts()
{
	int64_t now = util::timeMilli();
	TimerId tid{ now, 1LL << 62 };
	while (m_timers.size() && m_timers.begin()->first < tid)
	{
		Task task = move(m_timers.begin()->second);
		m_timers.erase(m_timers.begin());
		task();
	}
	RefreshNearest();
}

void EventBase::RefreshNearest(const TimerId * tid)
{
	if (m_timers.empty())
		m_nextTimeout = 1 << 30;
	else
	{
		const TimerId &t = m_timers.begin()->first;
		m_nextTimeout = t.first - util::timeMilli();
		m_nextTimeout = m_nextTimeout < 0 ? 0 : m_nextTimeout;
	}
}

//Fire one tick of a repeating timer: re-arm it first, then run the callback.
//`tr` points at a node inside m_timerReps; std::map nodes are stable across
//other insertions, so the pointer remains valid while the entry exists.
void EventBase::RepeatableTimeout(TimerRepeatable * tr)
{
	tr->m_at += tr->m_interval;
	//New one-shot entry keyed by (next deadline, fresh sequence number).
	tr->m_timerid = { tr->m_at, ++m_timerSeq };
	m_timers[tr->m_timerid] = [this, tr] { RepeatableTimeout(tr); };
	RefreshNearest(&tr->m_timerid);
	tr->m_cb();
}

//One loop iteration: poll for I/O, then fire any timers that came due.
//The wait is capped by the nearest timer deadline so timers are not delayed.
void EventBase::LoopOnce(int waitMs)
{
	const int wait = (std::min)(waitMs, m_nextTimeout);
	m_poller->LoopOnce(wait);
	HandleTimeouts();
}

//Cancel a pending timer. A negative first component marks a repeating timer
//(see RunAt); otherwise it is a one-shot. Returns false when the timer is
//unknown or already expired.
bool EventBase::Cancel(TimerId timerid)
{
	if (timerid.first < 0)	//repeating task
	{
		auto rep = m_timerReps.find(timerid);
		if (rep == m_timerReps.end())
			return false;
		//Drop the currently armed one-shot entry, then the descriptor itself.
		m_timers.erase(rep->second.m_timerid);
		m_timerReps.erase(rep);
		return true;
	}
	//One-shot task: erase by key reports whether it was still pending.
	return m_timers.erase(timerid) > 0;
}


//Main loop: poll until Exit() flips m_exit, then tear down timers and
//connections. The 3s cap keeps the loop responsive even with no timers.
void EventBase::Loop()
{
	while (!m_exit)
		LoopOnce(3000);

	//Drop all pending timer tasks.
	m_timerReps.clear();
	m_timers.clear();

	/*
	 * TcpConn::m_idleIds::m_lst points into the lists held by m_idleConns.
	 * Calling m_idleConns.clear() here would empty those lists and dangle
	 * every IdleId's list pointer and iterator, crashing later list
	 * operations — so for maintainability the clear() stays commented out.
	*/
	//m_idleConns.clear();

	m_conns.Clear();

	//One final iteration so cleanup events queued by the teardown above run.
	LoopOnce(0);
}

//Request loop exit. Only sets the atomic flag; Loop() observes it on its
//next iteration. Returns *this for chaining.
EventBase & EventBase::Exit()
{
	m_exit = true;
	return *this;
}

//True once Exit() has been requested.
bool EventBase::Exited()
{
	return m_exit;
}

void EventBase::Wakeup()
{
	int ret = ::send(m_wakeupFds[1], "", 1, 0);
	if (ret < 0)
		LCritical("Wakeup send error code:{}", errcode);
}

//Queue a task to run on the loop thread, then wake the loop so it is picked
//up promptly (tasks execute in the wakeup Channel's read handler).
void EventBase::SafeCall(Task && task)
{
	m_tasks.push(std::move(task));
	Wakeup();
}

//Schedule `task` at absolute time `milli` (ms). A non-zero `interval` makes
//it repeat. Returns the handle usable with Cancel(); a default-constructed
//TimerId when the loop is already exiting.
TimerId EventBase::RunAt(int64_t milli, Task && task, int64_t interval)
{
	//Refuse new timers once shutdown has begun.
	if (m_exit)
	{
		return TimerId();
	}
	if (interval)
	{
		//Repeating timer. The caller's handle uses -milli as its key so
		//Cancel() can tell it apart (first < 0 => look it up in m_timerReps).
		TimerId tid{ -milli, ++m_timerSeq };
		TimerRepeatable &rtr = m_timerReps[tid];
		//The inner TimerId (second ++m_timerSeq) keys the armed one-shot entry.
		rtr = { milli, interval,{ milli, ++m_timerSeq }, move(task) };
		//Pointer into the map node: std::map nodes are stable, so `tr` stays
		//valid for as long as the entry exists (until Cancel/Loop clears it).
		TimerRepeatable *tr = &rtr;
		m_timers[tr->m_timerid] = [this, tr] { RepeatableTimeout(tr); };
		RefreshNearest(&tr->m_timerid);
		return tid;
	}
	else
	{
		//One-shot timer keyed by (deadline, sequence number).
		TimerId tid{ milli, ++m_timerSeq };
		m_timers.insert({ tid, move(task) });
		RefreshNearest(&tid);
		return tid;
	}
}

//Hand out loops round-robin; the atomic counter only ever grows.
EventBase * MultiBase::AllocBase()
{
	const int ticket = m_id++;
	return &m_bases[ticket % m_bases.size()];
}

void MultiBase::Loop()
{
	int sz = m_bases.size();
	vector<thread> ths(sz - 1);
	for (int i = 0; i < sz - 1; i++)
	{
		thread t([this, i] { m_bases[i].Loop(); });
		ths[i].swap(t);
	}
	m_bases.back().Loop();
	for (int i = 0; i < sz - 1; i++)
		ths[i].join();
}

//Ask every underlying loop to stop; each returns from Loop() on its next
//iteration. Returns *this for chaining.
MultiBase & MultiBase::Exit()
{
	for (size_t i = 0; i < m_bases.size(); ++i)
		m_bases[i].Exit();
	return *this;
}


//Wrap `fd` and register it with the owning loop's poller immediately.
//NOTE(review): fd arrives as int while SOCKET is 64-bit on Win64 — the
//handle may already be truncated at the call site; confirm.
Channel::Channel(EventBase* base, int fd) :
	m_base(base), m_fd(fd),
	m_readable(true), m_writable(false)
{
	//Process-wide monotonically increasing channel id.
	static atomic<int64_t> id(0);
	m_id = ++id;
	m_base->m_poller->AddChannel(this);
}

//Close() is idempotent (no-op once m_fd is invalid), so destroying an
//already-closed channel is safe.
Channel::~Channel()
{
	LDebug("~Channel(), fd:{}", m_fd);
	Close();
}

//Close the socket and notify the owner. See the close-chain comments in
//EventBase.h: this can be re-entered via the read callback, which is why the
//fd is invalidated before HandleRead() runs.
void Channel::Close()
{
	if (m_fd != SOCKET_ERROR)
	{
		m_base->m_poller->RemoveChannel(this);
		//Graceful shutdown of the send side before releasing the handle.
		::shutdown(m_fd, SD_SEND);
		::closesocket(m_fd);
		//Invalidate first so nested Close() calls become no-ops.
		m_fd = SOCKET_ERROR;
		//Fire the read callback one last time so the owner can clean up.
		HandleRead();	
	}
}

//Record interest in read events and push the change to the poller.
void Channel::EnableRead(bool enable)
{
	m_readable = enable;
	m_base->m_poller->UpdateChannel(this);
}

//Record interest in write events and push the change to the poller.
void Channel::EnableWrite(bool enable)
{
	m_writable = enable;
	m_base->m_poller->UpdateChannel(this);
}

 

 类似资料: