LSP,SPI动态链接库(VC)
//********** DEF
LIBRARY
EXPORTS
WSPStartup
//********** C++
#define UNICODE // 必须
#define _UNICODE // 必须
#include <WinSock2.h>
#include <WS2spi.h>
#include <Windows.h>
#include <tchar.h>
#pragma comment( lib, "WS2_32.LIB" )
// Global: the next (lower) provider's WSPPROC_TABLE. WSPStartup copies it here after the
// lower provider's WSPStartup has filled it in; WSPSocket forwards calls through it.
WSPPROC_TABLE g_NextProcTable;
// Globals meant as scratch storage for the full service-provider list.
// NOTE(review): none of these three is referenced in the visible source - confirm before removing.
LPWSAPROTOCOL_INFOW g_ProtocolInfo = NULL;
DWORD g_ProtocolInfoSize = 0;
INT g_nTotalProtocols = 0;
// ============================================= Entry point =====================================================
// DllMain: this DLL needs no per-process or per-thread setup or teardown,
// so every loader notification simply reports success.
BOOL APIENTRY DllMain(
    HANDLE hModule,
    DWORD ul_reason_for_call,
    LPVOID lpReserved )
{
    UNREFERENCED_PARAMETER( hModule );
    UNREFERENCED_PARAMETER( ul_reason_for_call );
    UNREFERENCED_PARAMETER( lpReserved );
    return TRUE;
}
// ============================================= SPI functions =====================================================
// Enumerates the Winsock provider catalog with WSCEnumProtocols.
//
//  pProtocolInfo   - receives a GlobalAlloc'd array of catalog entries; release it with FreeProviders.
//  nTotalProtocols - receives the number of entries in that array.
//
// Returns TRUE only when the array was allocated AND filled. On FALSE, pProtocolInfo is
// NULL and nTotalProtocols is 0, so callers never see a NULL buffer paired with a count
// of SOCKET_ERROR (-1).
BOOL GetProviders(
    OUT LPWSAPROTOCOL_INFOW & pProtocolInfo,
    OUT INT & nTotalProtocols )
{
    DWORD nProtocolInfoSize = 0;
    INT ErrorCode = 0;
    pProtocolInfo = NULL;
    nTotalProtocols = 0;

    // Size query: a NULL buffer with size 0 is expected to fail with WSAENOBUFS and report the size.
    if( WSCEnumProtocols( NULL, NULL, &nProtocolInfoSize, &ErrorCode ) == SOCKET_ERROR &&
        ErrorCode != WSAENOBUFS )
        return FALSE;
    if( nProtocolInfoSize == 0 )
        return FALSE;

    LPWSAPROTOCOL_INFOW pBuffer =
        static_cast<LPWSAPROTOCOL_INFOW>( GlobalAlloc( GPTR, nProtocolInfoSize ) );
    if( pBuffer == NULL )
        return FALSE;

    INT nCount = WSCEnumProtocols( NULL, pBuffer, &nProtocolInfoSize, &ErrorCode );
    if( nCount == SOCKET_ERROR )
    {
        // Report the failure instead of returning TRUE with a freed/NULL buffer.
        GlobalFree( pBuffer );
        return FALSE;
    }

    pProtocolInfo = pBuffer;
    nTotalProtocols = nCount;
    return TRUE;
}
// Releases an array obtained from GetProviders (which allocates it with GlobalAlloc).
// GlobalFree's return value is deliberately discarded.
void FreeProviders(
    IN LPWSAPROTOCOL_INFOW pProtocolInfo )
{
    static_cast<void>( ::GlobalFree( static_cast<HGLOBAL>( pProtocolInfo ) ) );
}
// Finds the provider located beneath this LSP in a protocol chain and returns its DLL path.
//
//  pProtocolInfo    - the chain entry passed to our WSPStartup. ChainEntries[0] is this
//                     LSP's own layered entry, so the search starts at index 1.
//  szPathName       - receives the environment-expanded DLL path of the next provider.
//  cchPathName      - capacity of szPathName in WCHARs.
//  NextProtocolInfo - receives a copy of *pProtocolInfo (the caller's chain entry), NOT the
//                     next provider's own catalog entry. WSPStartup forwards this copy.
//
// The closest chain entry (lowest index >= 1) whose DLL path resolves is used.
// Returns TRUE on success. Returns FALSE on bad arguments, a chain with no lower entry,
// enumeration failure, or when no chain entry's path can be resolved into szPathName.
BOOL GetHookProvider(
    IN LPWSAPROTOCOL_INFOW pProtocolInfo,
    OUT LPWSTR szPathName,
    IN DWORD cchPathName,
    OUT WSAPROTOCOL_INFOW & NextProtocolInfo )
{
    if( pProtocolInfo == NULL || szPathName == NULL || cchPathName == 0 ) return FALSE;
    if( pProtocolInfo->ProtocolChain.ChainLen <= 1 ) return FALSE;

    LPWSAPROTOCOL_INFOW pProtoInfo = NULL;
    INT nTotalProtocols = 0;
    if( !GetProviders( pProtoInfo, nTotalProtocols ) ) return FALSE;

    // Never index past the fixed-size ChainEntries array, even for a malformed entry.
    INT nChainLen = pProtocolInfo->ProtocolChain.ChainLen;
    if( nChainLen > MAX_PROTOCOL_CHAIN ) nChainLen = MAX_PROTOCOL_CHAIN;

    BOOL bResult = FALSE;
    for( INT i = 1; i < nChainLen && !bResult; ++i )
    {
        DWORD dwWantedId = pProtocolInfo->ProtocolChain.ChainEntries[i];
        for( INT j = 0; j < nTotalProtocols; ++j )
        {
            if( pProtoInfo[j].dwCatalogEntryId != dwWantedId ) continue;

            WCHAR szRawPath[MAX_PATH];
            WCHAR szExpandedPath[MAX_PATH];
            INT nErrorNo = 0, nRawPathLen = MAX_PATH;
            if( WSCGetProviderPath(
                    &pProtoInfo[j].ProviderId,
                    szRawPath,
                    &nRawPathLen,
                    &nErrorNo ) != SOCKET_ERROR )
            {
                // Expand into a separate buffer; source and destination must not alias.
                // The return value counts the terminating NUL, or is the required size
                // when the buffer was too small.
                DWORD cchExpanded = ExpandEnvironmentStringsW( szRawPath, szExpandedPath, MAX_PATH );
                if( cchExpanded != 0 && cchExpanded <= MAX_PATH && cchExpanded <= cchPathName )
                {
                    CopyMemory( &NextProtocolInfo, pProtocolInfo, sizeof(WSAPROTOCOL_INFOW) );
                    wcscpy_s( szPathName, cchPathName, szExpandedPath );
                    bResult = TRUE;
                }
            }
            break; // only the first catalog entry carrying this ID is considered
        }
    }

    FreeProviders( pProtoInfo );
    return bResult;
}
// WSPSocket hook: emits a debug trace, then passes every argument unchanged to the next
// provider's WSPSocket (saved in g_NextProcTable by WSPStartup) and returns its result.
_Must_inspect_result_
SOCKET
WSPAPI WSPSocket(
    _In_ int af,
    _In_ int type,
    _In_ int protocol,
    _In_opt_ LPWSAPROTOCOL_INFOW lpProtocolInfo,
    _In_ GROUP g,
    _In_ DWORD dwFlags,
    _Out_ LPINT lpErrno
    )
{
    OutputDebugString( _T("LSPDLL: WSPSocket\n") );
    const SOCKET hNewSocket = g_NextProcTable.lpWSPSocket( af, type, protocol,
                                                           lpProtocolInfo, g, dwFlags, lpErrno );
    return hNewSocket;
}
// SPI entry point exported by this DLL (see EXPORTS WSPStartup in the DEF file).
//  1. GetHookProvider locates the next provider's DLL for lpProtocolInfo's chain.
//  2. That DLL is loaded, and its WSPStartup fills lpProcTable.
//  3. The lower table is saved in g_NextProcTable, and lpWSPSocket is replaced by our hook.
// Returns 0 on success. Returns WSAEPROVIDERFAILEDINIT if the next provider cannot be
// located, loaded or resolved, or the lower provider's own error code if its WSPStartup fails.
// On failure the DLL loaded here is released again. On success it stays loaded, because
// g_NextProcTable now points into it.
// NOTE(review): g_NextProcTable is a single global, overwritten by every successful call.
// If WSPStartup runs once per chain and the chains lead to different providers, only the
// last table survives - confirm that all chains share the same next provider.
_Must_inspect_result_
int
WSPAPI WSPStartup(
    _In_ WORD wVersionRequested,
    _In_ LPWSPDATA lpWSPData,
    _In_ LPWSAPROTOCOL_INFOW lpProtocolInfo,
    _In_ WSPUPCALLTABLE UpcallTable,
    _Out_ LPWSPPROC_TABLE lpProcTable
    )
{
    OutputDebugString( _T("LSPDLL: WSPStartup\n") );

    WCHAR szLibraryPath[MAX_PATH];
    WSAPROTOCOL_INFOW NextProtocolInfo;
    if( !GetHookProvider( lpProtocolInfo, szLibraryPath, MAX_PATH, NextProtocolInfo ) )
        return WSAEPROVIDERFAILEDINIT;

    HMODULE hLibraryHandle = LoadLibrary( szLibraryPath );
    if( hLibraryHandle == NULL )
        return WSAEPROVIDERFAILEDINIT;

    LPWSPSTARTUP lpfnWSPStartup = (LPWSPSTARTUP)GetProcAddress( hLibraryHandle, "WSPStartup" );
    if( lpfnWSPStartup == NULL )
    {
        FreeLibrary( hLibraryHandle ); // don't leak the module reference on failure
        return WSAEPROVIDERFAILEDINIT;
    }

    INT ErrorCode = lpfnWSPStartup(
        wVersionRequested,
        lpWSPData,
        &NextProtocolInfo,
        UpcallTable,
        lpProcTable );
    if( ErrorCode != ERROR_SUCCESS )
    {
        FreeLibrary( hLibraryHandle );
        return ErrorCode;
    }

    // Keep the lower provider's table for forwarding, then hook socket creation.
    g_NextProcTable = *lpProcTable;
    lpProcTable->lpWSPSocket = WSPSocket;
    return 0;
}
#define UNICODE // 必须
#define _UNICODE // 必须
#include <WS2spi.h>
#include <SpOrder.h> // WSCWriteProviderOrder
#include <Windows.h>
#include <stdio.h>
#include <tchar.h>
#include <atlbase.h>
#pragma comment( lib, "ws2_32.lib" )
#pragma comment( lib, "Rpcrt4.lib" ) // UuidCreate
// Hard-coded provider GUIDs for this LSP (they differ only in the last byte).
// g_ProviderGuid: ProviderId of the single layered entry installed by InstallProvider
// and removed by UninstallProvider.
GUID g_ProviderGuid = { //6acd2327-3dea-2123-4123-003792ead212
0x6acd2327,
0x3dea,
0x2123,
{0x41, 0x23, 0x00, 0x37, 0x92, 0xea, 0xd2, 0x12}
};
// g_ProviderChainGuid: ProviderId of the protocol-chain entries installed by InstallProvider
// and removed by UninstallProvider.
GUID g_ProviderChainGuid = { //6acd2327-3dea-2123-4123-003792ead213
0x6acd2327,
0x3dea,
0x2123,
{0x41, 0x23, 0x00, 0x37, 0x92, 0xea, 0xd2, 0x13}
};
// Enumerates the Winsock provider catalog with WSCEnumProtocols.
//
//  pProtocolInfo   - receives a GlobalAlloc'd array of catalog entries; release it with FreeProviders.
//  nTotalProtocols - receives the number of entries in that array.
//
// Returns TRUE only when the array was allocated AND filled. On FALSE, pProtocolInfo is
// NULL and nTotalProtocols is 0, so callers never see a NULL buffer paired with a count
// of SOCKET_ERROR (-1).
BOOL GetProviders(
    OUT LPWSAPROTOCOL_INFOW & pProtocolInfo,
    OUT INT & nTotalProtocols )
{
    DWORD nProtocolInfoSize = 0;
    INT ErrorCode = 0;
    pProtocolInfo = NULL;
    nTotalProtocols = 0;

    // Size query: a NULL buffer with size 0 is expected to fail with WSAENOBUFS and report the size.
    if( WSCEnumProtocols( NULL, NULL, &nProtocolInfoSize, &ErrorCode ) == SOCKET_ERROR &&
        ErrorCode != WSAENOBUFS )
        return FALSE;
    if( nProtocolInfoSize == 0 )
        return FALSE;

    LPWSAPROTOCOL_INFOW pBuffer =
        static_cast<LPWSAPROTOCOL_INFOW>( GlobalAlloc( GPTR, nProtocolInfoSize ) );
    if( pBuffer == NULL )
        return FALSE;

    INT nCount = WSCEnumProtocols( NULL, pBuffer, &nProtocolInfoSize, &ErrorCode );
    if( nCount == SOCKET_ERROR )
    {
        // Report the failure instead of returning TRUE with a freed/NULL buffer.
        GlobalFree( pBuffer );
        return FALSE;
    }

    pProtocolInfo = pBuffer;
    nTotalProtocols = nCount;
    return TRUE;
}
// Releases an array obtained from GetProviders (which allocates it with GlobalAlloc).
// GlobalFree's return value is deliberately discarded.
void FreeProviders(
    IN LPWSAPROTOCOL_INFOW pProtocolInfo )
{
    static_cast<void>( ::GlobalFree( static_cast<HGLOBAL>( pProtocolInfo ) ) );
}
// Turns a template catalog entry into a protocol chain that layers this LSP on top of it.
//
//  sName            - prefix for the chain's display name. The result is
//                     "<sName> [<template szProtocol>]", truncated to WSAPROTOCOL_LEN chars.
//  ProtocolInfo     - template entry; modified in place (name and ProtocolChain).
//  LayeredCatalogId - catalog ID of this LSP's layered entry; always becomes ChainEntries[0].
//  NextCatalogId    - catalog ID of the template entry. It becomes ChainEntries[1] only when
//                     the template is not already a chain (ChainLen <= BASE_PROTOCOL).
//  OutProtocolInfo  - receives a copy of the modified ProtocolInfo.
//
// If the template is already a chain (e.g. another LSP), our layered ID is inserted in
// front of its existing entries rather than overwriting its top entry. If that chain is
// already MAX_PROTOCOL_CHAIN long, no slot is left, so the chain is left unchanged instead
// of writing past ChainEntries.
void SetProtocolChain(
    IN const WCHAR *sName,
    IN OUT WSAPROTOCOL_INFOW &ProtocolInfo,
    IN DWORD LayeredCatalogId,
    IN DWORD NextCatalogId,
    OUT WSAPROTOCOL_INFOW &OutProtocolInfo
    )
{
    WCHAR ChainName[WSAPROTOCOL_LEN+1];
    // _TRUNCATE clips an over-long name instead of invoking the CRT invalid-parameter handler.
    _snwprintf_s( ChainName, WSAPROTOCOL_LEN+1, _TRUNCATE, L"%ls [%ls]",
                  sName ? sName : L"", ProtocolInfo.szProtocol );
    wcscpy_s( ProtocolInfo.szProtocol, WSAPROTOCOL_LEN+1, ChainName );

    WSAPROTOCOLCHAIN &Chain = ProtocolInfo.ProtocolChain;
    if( Chain.ChainLen <= BASE_PROTOCOL )
    {
        // Not yet a chain: [this LSP, template entry].
        Chain.ChainEntries[0] = LayeredCatalogId;
        Chain.ChainEntries[1] = NextCatalogId;
        Chain.ChainLen = 2;
    }
    else if( Chain.ChainLen < MAX_PROTOCOL_CHAIN )
    {
        // Existing chain: shift its entries down one slot and put this LSP on top.
        for( INT k = Chain.ChainLen; k > 0; --k )
            Chain.ChainEntries[k] = Chain.ChainEntries[k - 1];
        Chain.ChainEntries[0] = LayeredCatalogId;
        Chain.ChainLen++;
    }

    CopyMemory( &OutProtocolInfo, &ProtocolInfo, sizeof(WSAPROTOCOL_INFOW) );
}
BOOL InstallProvider(
LPCWSTR lpPathName,
LPCWSTR lpLSPName = L"BeaconLSP" )
{
LPWSAPROTOCOL_INFOW pProtocolInfo;
INT nTotalProtocol;
WSAPROTOCOL_INFOW ChainInfoUdp, ChainInfoTcp, ChainInfoRaw;
DWORD OrigCatalogIdUdp, OrigCatalogIdTcp, OrigCatalogIdRaw;
INT nArrayCount = 0;
DWORD LayeredCatalogId;
INT ErrorCode;
// 枚举所有服务程序提供者
if( !GetProviders( pProtocolInfo, nTotalProtocol ) ) return FALSE;
BOOL bRawIp = FALSE, bUdpIp = FALSE, bTcpIp = FALSE;
for( INT i = 0; i < nTotalProtocol; ++i )
{
if( pProtocolInfo[i].iAddressFamily == AF_INET )
{
#define COPY_PROTO_INFO( ChainInfo, OrigCatalogId ) \
CopyMemory( &ChainInfo, &pProtocolInfo[i], sizeof(WSAPROTOCOL_INFOW) ); \
ChainInfo.dwServiceFlags1 = ChainInfo.dwServiceFlags1 & (~XP1_IFS_HANDLES); \
OrigCatalogId = pProtocolInfo[i].dwCatalogEntryId;
if( !bRawIp && pProtocolInfo[i].iProtocol == IPPROTO_IP )
{
COPY_PROTO_INFO( ChainInfoRaw, OrigCatalogIdRaw );
bRawIp = TRUE;
}
if( !bUdpIp && pProtocolInfo[i].iProtocol == IPPROTO_UDP )
{
COPY_PROTO_INFO( ChainInfoUdp, OrigCatalogIdUdp );
bUdpIp = TRUE;
}
if( !bTcpIp && pProtocolInfo[i].iProtocol == IPPROTO_TCP )
{
COPY_PROTO_INFO( ChainInfoTcp, OrigCatalogIdTcp );
bTcpIp = TRUE;
}
#undef COPY_PROTO_INFO
}
}
// 安装协议
WSAPROTOCOL_INFOW LayeredProtocolInfo;
CopyMemory( &LayeredProtocolInfo, &ChainInfoRaw, sizeof(WSAPROTOCOL_INFOW) );
if( lpLSPName ) wcscpy_s( LayeredProtocolInfo.szProtocol, WSAPROTOCOL_LEN+1, lpLSPName );
LayeredProtocolInfo.ProtocolChain.ChainLen = LAYERED_PROTOCOL;
// 安装
if( WSCInstallProvider(
&g_ProviderGuid,
lpPathName,
&LayeredProtocolInfo,
1,
&ErrorCode ) == SOCKET_ERROR )
{
return FALSE;
}
// 重新枚举协议,获取分层协议的目录ID号
FreeProviders( pProtocolInfo );
if( !GetProviders( pProtocolInfo, nTotalProtocol ) ) return FALSE;
for( INT i = 0; i < nTotalProtocol; ++i )
{
if( IsEqualGUID( pProtocolInfo[i].ProviderId, g_ProviderGuid ) )
{
LayeredCatalogId = pProtocolInfo[i].dwCatalogEntryId;
break;
}
}
INT ProvCnt = 0;
WSAPROTOCOL_INFOW ChainArray[3];
if( bRawIp )
{
SetProtocolChain(
L"LayeredCatalogId RAW/IP over",
ChainInfoRaw,
LayeredCatalogId,
OrigCatalogIdRaw,
ChainArray[ProvCnt++] );
}
if( bUdpIp )
{
SetProtocolChain(
L"LayeredCatalogId UDP/IP over",
ChainInfoUdp,
LayeredCatalogId,
OrigCatalogIdUdp,
ChainArray[ProvCnt++] );
}
if( bTcpIp )
{
SetProtocolChain(
L"LayeredCatalogId TCP/IP over",
ChainInfoTcp,
LayeredCatalogId,
OrigCatalogIdTcp,
ChainArray[ProvCnt++] );
}
if( WSCInstallProvider(
&g_ProviderChainGuid,
lpPathName,
ChainArray,
ProvCnt,
&ErrorCode ) == SOCKET_ERROR )
{
return FALSE;
}
FreeProviders( pProtocolInfo );
if( !GetProviders( pProtocolInfo, nTotalProtocol ) ) return FALSE;
LPDWORD CatalogEntries = (LPDWORD)GlobalAlloc( GPTR, nTotalProtocol * sizeof(DWORD) );
if( CatalogEntries == NULL )
{
FreeProviders( pProtocolInfo );
return FALSE;
}
INT CatIndex = 0;
for( INT i = 0; i < nTotalProtocol; ++i )
{
if( IsEqualGUID( pProtocolInfo[i].ProviderId, g_ProviderGuid ) ||
IsEqualGUID( pProtocolInfo[i].ProviderId, g_ProviderChainGuid ) )
{
CatalogEntries[CatIndex++] = pProtocolInfo[i].dwCatalogEntryId;
}
}
for( INT i = 0; i < nTotalProtocol; ++i )
{
if( !IsEqualGUID( pProtocolInfo[i].ProviderId, g_ProviderGuid ) &&
!IsEqualGUID( pProtocolInfo[i].ProviderId, g_ProviderChainGuid ) )
{
CatalogEntries[CatIndex++] = pProtocolInfo[i].dwCatalogEntryId;
}
}
ErrorCode = WSCWriteProviderOrder( CatalogEntries, nTotalProtocol );
if( ErrorCode != ERROR_SUCCESS )
{
FreeProviders( pProtocolInfo );
return FALSE;
}
FreeProviders( pProtocolInfo );
return TRUE;
}
// Removes this LSP's catalog entries.
// The chains (g_ProviderChainGuid) are removed before the layered entry (g_ProviderGuid),
// because each chain's ChainEntries[0] refers to the layered entry. Both removals are
// always attempted, so one failure cannot leave orphaned chains behind.
// Returns TRUE only if both WSCDeinstallProvider calls succeed.
BOOL UninstallProvider()
{
    INT ErrorCode = 0;
    BOOL bChainsRemoved = WSCDeinstallProvider( &g_ProviderChainGuid, &ErrorCode ) != SOCKET_ERROR;
    BOOL bLayeredRemoved = WSCDeinstallProvider( &g_ProviderGuid, &ErrorCode ) != SOCKET_ERROR;
    return ( bChainsRemoved && bLayeredRemoved ) ? TRUE : FALSE;
}
#define LSP_DLL_PATH L"C:\\SPL.dll"
int _tmain( int argc, TCHAR * argv[] )
{
BOOL bResult = InstallProvider( LSP_DLL_PATH );
getchar();
bResult = UninstallProvider();
return 0;
}