现在的位置: 首页 > 综合 > 正文

vc++ socket实现的支持断点续传的下载器

2012年11月18日 ⁄ 综合 ⁄ 共 24814字 ⁄ 字号 评论关闭

网上找了一堆代码,有用wininet的,还有用socket的,整理了半天,还是觉得socket靠谱。

只支持内存中断点续传。如果要加上在磁盘上断点续传,原理也差不多,不是本文重点。

注释:

1. CByteBufferVector是一个缓存池,动态分配BYTE形数组空间用的。代码略,可以简单看成BYTE数组。

2. GetStringA是一个CString转CStringA的函数,无需多说。

3. 除了win socket基本没有其它依赖,噢对,ATL::CString除外……

头文件:

  1. class CSocketDownloader;  
  2.   
  3.   
  4. /** 
  5.  *<SPAN style="WHITE-SPACE: pre">  </SPAN>下载任务 
  6.  */  
  7. class CDownloadTask  
  8. {  
  9.     friend class CSocketDownloader;  
  10.   
  11.   
  12. public:  
  13.   
  14.   
  15.     CDownloadTask();  
  16.   
  17.   
  18.     CStringA GetUrlA() const;  
  19.     CStringA GetAgnetA() const;  
  20.     void ParseUrl();  
  21.       
  22.     int Percentage() const;  
  23.     DWORD RemainTimeSec(DWORD dwTickElapsed, unsigned int uBytesTransferred) const;  
  24.   
  25.   
  26.     CString         m_strUrl;           // 下载地址
      
  27.     CString         m_strAgent;         // 用户agent
      
  28.     int             m_nMaxTryCount;     // 最多重试次数(重定向不算重试,默认20次)
      
  29.     int             m_nTimeoutSec;      // socket超时(秒,默认10秒)
      
  30.     int             m_nPort;            // 端口(默认80)
      
  31.     HWND            m_hWnd;             // 接收下载进度消息的窗口句柄
      
  32.     LONG            *m_pTerminate;      // 指向是否中止的标志位,一般由用户界面操作(如点击“取消”按钮)更改此值
      
  33.   
  34.   
  35. protected:  
  36.   
  37.   
  38.     CStringA        m_strAbsoluteUrlA;  
  39.     CStringA        m_strQueryA;  
  40.     CStringA        m_strHostA;  
  41.     unsigned int    m_uReadBytes;  
  42.     unsigned int    m_uTotalBytes;  
  43. };  
  44.   
  45.   
  46. /** 
  47.  *<SPAN style="WHITE-SPACE: pre">  </SPAN>socket实现的断点续传下载器 
  48.  */  
  49. class CSocketDownloader  
  50. {  
  51. public:  
  52.     CSocketDownloader();  
  53.     virtual ~CSocketDownloader();  
  54.   
  55.   
  56.     // 下载到一个buffer   
  57.     DWORD DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);  
  58.   
  59.   
  60.     // 下载到一个文件   
  61.     DWORD DownloadToFile(CDownloadTask &task, CString strOutputFile);  
  62.   
  63.   
  64. protected:  
  65.   
  66.   
  67.     DWORD DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);  
  68.     DWORD ConnectServer(const CDownloadTask &task, SOCKET hSocket);  
  69.     DWORD DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket);  
  70.       
  71.     int GetSleepSecCount(int nTryCount) const;  
  72.     int GetBufferSize(const CDownloadTask &task) const;  
  73.     CStringA GenerateRequest(CDownloadTask &task) const;  
  74.   
  75.   
  76. };  

CPP文件:

  1. <SPAN style="FONT-FAMILY: Arial, Verdana, sans-serif"><SPAN style="WHITE-SPACE: normal"><SPAN style="FONT-FAMILY: monospace"><SPAN style="WHITE-SPACE: pre">#include <math.h>  
  2. #include <time.h>   
  3.   
  4.   
  5. const int BLOCK_SIZE = 1024 * 64;  
  6. const int DEFAULT_MAX_TRY = 20;  
  7. const int DEFAULT_TIMEOUT = 10;  
  8.   
  9.   
  10. //////////////////////////////////////////////////////////////////////////
      
  11. // 下载任务   
  12. //////////////////////////////////////////////////////////////////////////
      
  13.   
  14.   
  15. CDownloadTask::CDownloadTask()  
  16.     : m_nPort(INTERNET_DEFAULT_HTTP_PORT),  
  17.     m_nMaxTryCount(DEFAULT_MAX_TRY),  
  18.     m_uReadBytes(0),  
  19.     m_uTotalBytes(0),  
  20.     m_nTimeoutSec(DEFAULT_TIMEOUT),  
  21.     m_hWnd(NULL),  
  22.     m_pTerminate(NULL)  
  23. {  
  24.   
  25.   
  26. }  
  27.   
  28.   
  29. CStringA CDownloadTask::GetUrlA() const  
  30. {  
  31.     return GetStringA(m_strUrl);  
  32. }  
  33.   
  34.   
  35. CStringA CDownloadTask::GetAgnetA() const  
  36. {  
  37.     return GetStringA(m_strAgent);  
  38. }  
  39.   
  40.   
  41. void CDownloadTask::ParseUrl()  
  42. {  
  43.     m_strAbsoluteUrlA = m_strHostA = m_strQueryA = "";  
  44.   
  45.   
  46.     CStringA strUrlA = this->GetUrlA();  
  47.     const char *pUrl = strUrlA;  
  48.     const char *p = pUrl;  
  49.     const char *szHttpHead = "http://";  
  50.   
  51.   
  52.     if (_strnicmp(pUrl, szHttpHead, strlen(szHttpHead)) == 0)  
  53.     {  
  54.         p = pUrl + strlen(szHttpHead);  
  55.     }  
  56.   
  57.   
  58.     int nHostLen = 0;  
  59.     const char *q = strchr(p, '/');  
  60.     if (q != NULL)  
  61.     {  
  62.         nHostLen = q - p;  
  63.         int nPathLen = 0;  
  64.         const char *r = strchr(q, '?');  
  65.         if (r != NULL)  
  66.         {  
  67.             // 解析query
      
  68.             r++;  
  69.             m_strQueryA = r;  
  70.             nPathLen = r - q - 1;  
  71.         }  
  72.         else  
  73.         {  
  74.             nPathLen = strlen(q);  
  75.         }  
  76.   
  77.   
  78.         // 解析abs_path   
  79.         m_strAbsoluteUrlA.Append(q, nPathLen);  
  80.     }  
  81.     else  
  82.     {  
  83.         nHostLen = strlen(p);  
  84.     }  
  85.   
  86.   
  87.     // 解析host   
  88.     m_strHostA.Append(p, nHostLen);  
  89.   
  90.   
  91.     // 解析port   
  92.     const char *r = strchr(m_strHostA, ':');  
  93.     if (r == 0)  
  94.     {  
  95.         m_nPort = INTERNET_DEFAULT_HTTP_PORT;  
  96.     }  
  97.     else  
  98.     {  
  99.         m_nPort = atoi(r + 1);  
  100.     }  
  101. }  
  102.   
  103.   
  104. int CDownloadTask::Percentage() const  
  105. {  
  106.     return (m_uTotalBytes == 0)  
  107.         ? 0  
  108.         : (int)((unsigned long long)m_uReadBytes * 100 / (unsigned long long) m_uTotalBytes);  
  109. }  
  110.   
  111.   
  112. DWORD CDownloadTask::RemainTimeSec( DWORD dwTickElapsed, unsigned int uBytesTransferred ) const  
  113. {  
  114.     unsigned long long uTickElapsed = (unsigned long long)dwTickElapsed;  
  115.     unsigned long long uBytes = (unsigned long long)uBytesTransferred;  
  116.     unsigned long long uRemain = (unsigned long long)(m_uTotalBytes - m_uReadBytes);  
  117.     Log(_T("elapsed=%d, get=%d, remain=%d\n"), dwTickElapsed, uBytesTransferred, m_uTotalBytes - m_uReadBytes);  
  118.     return (DWORD)(uTickElapsed * uRemain / (uBytes * CLOCKS_PER_SEC));  
  119. }  
  120.   
  121.   
  122. //////////////////////////////////////////////////////////////////////////
      
  123. // socket下载器   
  124. //////////////////////////////////////////////////////////////////////////
      
  125.   
  126.   
  127. CSocketDownloader::CSocketDownloader()  
  128. {  
  129.   
  130.   
  131. }  
  132.   
  133.   
  134. CSocketDownloader::~CSocketDownloader()  
  135. {  
  136.   
  137.   
  138. }  
  139.   
  140.   
  141. DWORD CSocketDownloader::DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)  
  142. {  
  143.     int nTryCount = 0;  
  144.     DWORD dwRet = this->DoDownloadToBuffer(task, bufVec);  
  145.     if (web::THE_REDIRECT != dwRet)  
  146.     {  
  147.         nTryCount++;  
  148.     }  
  149.   
  150.   
  151.     while (  
  152.         dwRet != web::THE_SUCCEED  
  153.         && dwRet != web::THE_USER_CANCELED  
  154.         && nTryCount < task.m_nMaxTryCount  
  155.         )  
  156.     {  
  157.         int nTime = this->GetSleepSecCount(nTryCount);  
  158.         ::Sleep(nTime);  
  159.         dwRet = this->DoDownloadToBuffer(task, bufVec);  
  160.         if (web::THE_REDIRECT != dwRet)  
  161.         {  
  162.             nTryCount++;  
  163.         }  
  164.     }  
  165.   
  166.   
  167.     return dwRet;  
  168. }  
  169.   
  170.   
  171. DWORD CSocketDownloader::DownloadToFile( CDownloadTask &task, CString strOutputFile )  
  172. {  
  173.     CByteBufferVector vec;  
  174.   
  175.   
  176.     DWORD dwRet = this->DownloadToBuffer(task, vec);  
  177.     if (web::THE_SUCCEED != dwRet)  
  178.     {  
  179.         return dwRet;  
  180.     }  
  181.   
  182.   
  183.     HANDLE hFile = ::CreateFile(strOutputFile, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);  
  184.     if (hFile == INVALID_HANDLE_VALUE)  
  185.     {  
  186.         return web::THE_CREATE_FILE;  
  187.     }  
  188.   
  189.   
  190.     BYTE *pBuffer = vec.Ptr(0, task.m_uTotalBytes);  
  191.     DWORD dwBytesWritten = 0;  
  192.     ::WriteFile(hFile, pBuffer, task.m_uTotalBytes, &dwBytesWritten, NULL);  
  193.     ::CloseHandle(hFile);  
  194.     return (dwBytesWritten == task.m_uTotalBytes) ? web::THE_SUCCEED : web::THE_WRITE_FILE;  
  195. }  
  196.   
  197.   
  198. DWORD CSocketDownloader::DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)  
  199. {  
  200.     task.ParseUrl();  
  201.   
  202.   
  203.     SOCKET hSocket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);  
  204.     if (hSocket == INVALID_SOCKET)  
  205.     {  
  206.         return web::THE_CREATE_SOCKET;  
  207.     }  
  208.   
  209.   
  210.     DWORD dwRet = this->ConnectServer(task, hSocket);  
  211.     if (web::THE_SUCCEED != dwRet)  
  212.     {  
  213.         closesocket(hSocket);  
  214.         return dwRet;  
  215.     }  
  216.   
  217.   
  218.     dwRet =  this->DoDownloadToBufferInner(task, bufVec, hSocket);  
  219.     closesocket(hSocket);  
  220.     return dwRet;  
  221. }  
  222.   
  223.   
  224. DWORD CSocketDownloader::DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket)  
  225. {  
  226.     // 发送请求   
  227.     CStringA strRequest = this->GenerateRequest(task);  
  228.     int nLen = send(hSocket, strRequest, strRequest.GetLength(), 0);  
  229.     if (nLen <= 0)  
  230.     {  
  231.         return web::THE_SEND_HTTP_HEADER;  
  232.     }  
  233.   
  234.   
  235.     // 接收一部分数据(header部分,以"\r\n\r\n"为止)
      
  236.     CStringA strRecvBuf;  
  237.     char szRecvBuf[MAX_PATH] = { 0 };  
  238.     nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);  
  239.     while (nLen > 0)  
  240.     {  
  241.         szRecvBuf[nLen] = 0;  
  242.         strRecvBuf.Append(szRecvBuf);  
  243.         if (strstr(szRecvBuf, "\r\n\r\n") != NULL)  
  244.         {  
  245.             break;  
  246.         }  
  247.         nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);  
  248.     }  
  249.       
  250.     // 找到两个回车换行,即content起始位置。   
  251.     const char *pData = strstr(szRecvBuf, "\r\n\r\n");  
  252.     if (pData == NULL)  
  253.     {  
  254.         return web::THE_INVALID_RECV_END;  
  255.     }  
  256.   
  257.   
  258.     pData += 4;  
  259.   
  260.   
  261.     const char *p = strchr(strRecvBuf, ' ');  
  262.     if (p != NULL)  
  263.     {  
  264.         p++;  
  265.         DWORD dwRet = atoi(p);  
  266.   
  267.   
  268.         if (dwRet == HTTP_STATUS_PARTIAL_CONTENT)        // 206: 断点续传
      
  269.         {  
  270.             const char *q = strstr(strRecvBuf, "\r\nContent-Length:");  
  271.             if (q == NULL)  
  272.             {  
  273.                 return web::THE_NO_CONTENT_LENGTH;  
  274.             }  
  275.             task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);  
  276.         }  
  277.         else if (dwRet == HTTP_STATUS_OK)               // 200: 重新下载(服务器不支持断点续传)
      
  278.         {  
  279.             const char *q = strstr(strRecvBuf, "\r\nContent-Length:");  
  280.             if (q == NULL)  
  281.             {  
  282.                 return web::THE_NO_CONTENT_LENGTH;  
  283.             }  
  284.             task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);  
  285.             // 清除已经下载的内容
      
  286.             task.m_uReadBytes = 0;  
  287.             bufVec.Reset();  
  288.         }  
  289.         else if (dwRet == HTTP_STATUS_REDIRECT)         // 302: 重定向
      
  290.         {  
  291.             const char *q = strstr(strRecvBuf, "\r\nLocation:");  
  292.             if (q == NULL)  
  293.             {  
  294.                 return web::THE_NO_REDIRECT_LOCATION;  
  295.             }  
  296.             q += 12;  
  297.             const char *r = strstr(q, "\r\n");  
  298.             if (r == NULL)  
  299.             {  
  300.                 return web::THE_REDIRECT_INVALID_FORMAT;  
  301.             }  
  302.   
  303.   
  304.             int nUrlLen = r - q;  
  305.             CStringA strUrlA;  
  306.             strUrlA.Append(q, nUrlLen);  
  307.             task.m_strUrl = GetString(strUrlA);  
  308.             return web::THE_REDIRECT;  
  309.         }  
  310.         else  
  311.         {  
  312.             return web::THE_INVALID_STAUS_CODE;  
  313.         }  
  314.     }  
  315.       
  316.     // 复制已传回来的第一部分content   
  317.     int nSize = nLen - (pData - szRecvBuf);  
  318.     BYTE *pBuffer = bufVec.Ptr(task.m_uReadBytes, nSize);  
  319.     memcpy(pBuffer, pData, nSize);  
  320.     task.m_uReadBytes += nSize;  
  321.   
  322.   
  323.     // 继续接收http content,即下载内容。
      
  324.     int nBufferSize = this->GetBufferSize(task);  
  325.     pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);  
  326.   
  327.   
  328.     DWORD dwLastTick = 0;  
  329.   
  330.   
  331.     // 下载测速   
  332.     DWORD dwTickStart = ::GetTickCount();  
  333.     unsigned int uReadBytesStart = task.m_uReadBytes;  
  334.   
  335.   
  336.     while (true)  
  337.     {  
  338.         if (::InterlockedCompareExchange(task.m_pTerminate, 1, 1))  
  339.         {  
  340.             // 用户取消。   
  341.             return web::THE_USER_CANCELED;  
  342.         }  
  343.   
  344.   
  345.         nLen = recv(hSocket, (char *)(pBuffer), nBufferSize, 0);  
  346.         if (nLen < 0)  
  347.         {  
  348.             return web::THE_RECV_FAIL;  
  349.         }  
  350.         else if (nLen == 0)  
  351.         {  
  352.             break;  // 接收完成
      
  353.         }  
  354.   
  355.   
  356.         task.m_uReadBytes += nLen;  
  357.         if (task.m_uReadBytes == task.m_uTotalBytes)  
  358.         {  
  359.             break;  // 接收完成
      
  360.         }  
  361.   
  362.   
  363.         nBufferSize = this->GetBufferSize(task);  
  364.         pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);  
  365.   
  366.   
  367.         if (NULL != task.m_hWnd)  
  368.         {  
  369.             DWORD dwTick = ::GetTickCount();  
  370.             if (dwLastTick == 0 || (dwTick - dwLastTick >= 100))        // 每秒最多发10次消息
      
  371.             {  
  372.                 // 发送当前下载进度和剩余时间消息
      
  373.                 dwLastTick = dwTick;  
  374.                 ::PostMessage(task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,  
  375.                     static_cast<WPARAM>(task.Percentage()), static_cast<LPARAM>(task.RemainTimeSec(dwTick - dwTickStart, task.m_uReadBytes - uReadBytesStart))  
  376.                     );  
  377.             }  
  378.         }  
  379.     }  
  380.   
  381.   
  382.     DWORD dwTick = ::GetTickCount();  
  383.     if (NULL != task.m_hWnd)  
  384.     {  
  385.         ::PostMessage(  
  386.             task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,  
  387.             static_cast<WPARAM>(task.Percentage()), -1  
  388.             );  
  389.     }  
  390.   
  391.   
  392.     return web::THE_SUCCEED;  
  393. }  
  394.   
  395.   
  396. DWORD CSocketDownloader::ConnectServer(const CDownloadTask &task, SOCKET hSocket)  
  397. {  
  398.     PHOSTENT pHostent = gethostbyname(task.m_strHostA);  
  399.     if (pHostent == NULL)  
  400.     {  
  401.         return web::THE_GET_HOST_BY_NAME;  
  402.     }  
  403.   
  404.   
  405.     sockaddr_in addrSvr;  
  406.     addrSvr.sin_port = htons((u_short)task.m_nPort);  
  407.     addrSvr.sin_family = AF_INET;  
  408.     addrSvr.sin_addr.s_addr = *(ULONG*)pHostent->h_addr_list[0];  
  409.     if (SOCKET_ERROR == connect(hSocket, (sockaddr*)&addrSvr, sizeof(addrSvr)))  
  410.     {  
  411.         return web::THE_CONNECT_SOCKET;  
  412.     }  
  413.   
  414.   
  415.     int opt = task.m_nTimeoutSec * 1000;  
  416.     if (0 != setsockopt(hSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&opt, sizeof(opt)))  
  417.     {  
  418.         return web::THE_SET_SOCK_OPT1;  
  419.     }  
  420.   
  421.   
  422.     BOOL bKeepAlive = TRUE;    
  423.     int len = sizeof(bKeepAlive);  
  424.     getsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char*)&bKeepAlive, &len);  
  425.   
  426.   
  427.     bKeepAlive = TRUE;  
  428.     if (0 != setsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char *)&bKeepAlive, sizeof(BOOL)))  
  429.     {  
  430.         return web::THE_SET_SOCK_OPT2;  
  431.     }  
  432.   
  433.   
  434.     return web::THE_SUCCEED;  
  435. }  
  436.   
  437.   
  438. int CSocketDownloader::GetSleepSecCount( int nTryCount ) const  
  439. {  
  440.     return (nTryCount + 1) * 1000;  
  441. }  
  442.   
  443.   
  444. int CSocketDownloader::GetBufferSize( const CDownloadTask &task ) const  
  445. {  
  446.     return std::min<int>(BLOCK_SIZE, task.m_uTotalBytes - task.m_uReadBytes);  
  447. }  
  448.   
  449.   
  450. CStringA CSocketDownloader::GenerateRequest( CDownloadTask &task ) const  
  451. {  
  452.     CStringA strRequest;  
  453.   
  454.   
  455.     CStringA strTemp;  
  456.     if (task.m_strQueryA.IsEmpty())  
  457.     {  
  458.         strTemp.Format(  
  459.             "GET %s HTTP/1.1\r\nHOST: %s\r\n",  
  460.             task.m_strAbsoluteUrlA.GetString(), task.m_strHostA.GetString()  
  461.             );  
  462.     }  
  463.     else  
  464.     {  
  465.         strTemp.Format(  
  466.             "GET %s?%s HTTP/1.1\r\nHOST: %s\r\n",  
  467.             task.m_strAbsoluteUrlA.GetString(), task.m_strQueryA.GetString(), task.m_strHostA.GetString()  
  468.             );  
  469.     }  
  470.     strRequest.Append(strTemp);  
  471.   
  472.   
  473.     strTemp.Format("Range: bytes=%d-\r\n", task.m_uReadBytes);  
  474.     strRequest.Append(strTemp);  
  475.   
  476.   
  477.     strTemp.Format("User-Agent: %s\r\n", task.GetAgnetA().GetString());  
  478.     strRequest.Append(strTemp);  
  479.   
  480.   
  481.     strRequest.Append("Accept: */*\r\n");  
  482.     strRequest.Append("Accept-Encoding: gzip, deflate\r\n");  
  483.     strRequest.Append("Connection: Keep-Alive\r\n\r\n");  
  484.       
  485.     return strRequest;  
  486. }  
  487. </SPAN></SPAN></SPAN></SPAN>  

抱歉!评论已关闭.