python 多线程任务(下载)水平扩展线程

方案架构说明

这个方案由两个主要组件构成:


  1. 任务分发器 (TaskDistributor)
    • mysql 数据库中获取待下载的文件

    • 将下载任务放入 Redis 队列

    • 标记文件状态为 "处理中"

  2. 工作节点 (DownloadWorker)
    • 从 Redis 队列获取下载任务

    • 执行实际的文件下载

    • 更新数据库中的文件状态

import requests
import pymysql
import os
import time
import logging
import json
import redis
import argparse
from urllib.parse import urlparse
from requests.exceptions import RequestException
from datetime import datetime

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='downloader.log'
)
logger = logging.getLogger(__name__)

class DownloadWorker:
    def __init__(self, db_config, redis_config, download_dir='downloads'):
        """初始化下载工作节点"""
        self.db_config = db_config
        self.redis_config = redis_config
        self.download_dir = download_dir
        self.running = True
        
        self.redis_clIEnt = redis.Redis(**redis_config)
        self.queue_name = "download_queue"
        self.exit_flag_key = "download_exit_flag"
        
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (windows NT 10.0; Win64; x64) APpleWebKit/537.36 (Khtml, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        })
        
        # 创建下载目录
        os.makedirs(download_dir, exist_ok=True)
        
        # 确保数据表存在
        self._create_table()
    
    def _get_db_connection(self):
        """获取数据库连接"""
        return pymysql.connect(**self.db_config)
    
    def _create_table(self):
        """创建数据表"""
        with self._get_db_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS files (
                        id VARCHAR(255) PRIMARY KEY,
                        url TEXT NOT NULL,
                        status TINYINT NOT NULL DEFAULT 0,
                        file_path TEXT,
                        download_time FLOAT,
                        error_message TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    ) ENGINE=InnoDB DEFAULT CHARSET=UTF8mb4;
                ''')
                # 创建索引以加速查询
                cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON files (status)')
            conn.commit()
    
    def download_file(self, file):
        """下载单个文件"""
        file_id = file['id']
        url = file['url']
        start_time = time.time()
        
        try:
            # 生成保存路径
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)
            if not filename:
                filename = f"file_{file_id}"
            file_path = os.path.join(self.download_dir, filename)
            
            # 检查文件是否已存在
            if os.path.exists(file_path):
                self._update_status(file_id, 1, file_path, time.time() - start_time)
                return file_id, True, f"文件已存在: {file_path}"
            
            # 下载文件
            response = self.session.get(url, stream=True, timeout=30)
            response.raise_for_status()
            
            # 保存文件
            with open(file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            
            download_time = time.time() - start_time
            self._update_status(file_id, 1, file_path, download_time)
            logger.info(f"下载成功 (ID: {file_id}): {url} ({download_time:.2f}s)")
            return file_id, True, f"下载成功 ({download_time:.2f}s)"
            
        except RequestException as e:
            error_msg = f"下载失败: {str(e)}"
            self._update_status(file_id, -1, None, time.time() - start_time, error_msg)
            logger.error(f"下载失败 (ID: {file_id}): {error_msg}")
            return file_id, False, error_msg
        except Exception as e:
            error_msg = f"未知错误: {str(e)}"
            self._update_status(file_id, -1, None, time.time() - start_time, error_msg)
            logger.error(f"未知错误 (ID: {file_id}): {error_msg}")
            return file_id, False, error_msg
    
    def _update_status(self, file_id, status, file_path=None, download_time=None, error_message=None):
        """更新文件下载状态"""
        with self._get_db_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    '''UPDATE files 
                       SET status = %s, file_path = %s, download_time = %s, error_message = %s, updated_at = CURRENT_TIMESTAMP
                       WHERE id = %s''',
                    (status, file_path, download_time, error_message, file_id)
                )
            conn.commit()
    
    def process_task(self):
        """从队列获取并处理一个任务"""
        # 检查退出标志
        if self.redis_client.get(self.exit_flag_key):
            logger.info("检测到退出标志,准备退出...")
            self.running = False
            return False
        
        # 使用阻塞方式从队列获取任务,超时时间设为1秒
        _, task_json = self.redis_client.blpop(self.queue_name, timeout=1)
        
        if task_json:
            try:
                task = json.loads(task_json)
                file_id, success, message = self.download_file(task)
                
                # 记录处理结果
                result_key = f"download_result:{file_id}"
                result_data = {
                    'id': file_id,
                    'success': success,
                    'message': message,
                    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                }
                self.redis_client.setex(result_key, 86400, json.dumps(result_data))
                
                return success
                
            except json.JSONDecodeError:
                logger.error(f"无效的任务格式: {task_json}")
                return False
            except Exception as e:
                logger.error(f"处理任务时出错: {str(e)}", exc_info=True)
                return False
        
        return None  # 没有获取到任务
    
    def run(self, mode='forever'):
        """运行工作节点
        mode: 'forever' - 持续运行直到退出标志
              'once' - 处理一个任务后退出
              'batch' - 处理一批任务后退出
        """
        logger.info(f"下载工作节点启动,运行模式: {mode}")
        
        try:
            processed_count = 0
            success_count = 0
            
            if mode == 'forever':
                while self.running:
                    result = self.process_task()
                    
                    if result is not None:  # 处理了一个任务
                        processed_count += 1
                        if result:
                            success_count += 1
                            
                        # 每处理100个任务打印一次统计信息
                        if processed_count % 100 == 0:
                            logger.info(f"已处理任务: {processed_count}, 成功: {success_count}, 失败: {processed_count - success_count}")
            
            elif mode == 'once':
                self.process_task()
            
            elif mode == 'batch':
                while True:
                    result = self.process_task()
                    if result is None:  # 没有更多任务
                        break
                    
                    processed_count += 1
                    if result:
                        success_count += 1
            
            logger.info(f"工作节点退出,处理结果: 成功={success_count}, 失败={processed_count-success_count}, 总计={processed_count}")
            
        except KeyboardInterrupt:
            logger.info("用户中断,程序退出")
            self.running = False
        except Exception as e:
            logger.critical(f"程序异常: {str(e)}", exc_info=True)
            self.running = False
    
    def set_exit_flag(self):
        """设置退出标志"""
        self.redis_client.set(self.exit_flag_key, '1')
        logger.info("已设置退出标志")

class TaskDistributor:
    def __init__(self, db_config, redis_config, batch_size=1000):
        """初始化任务分发器"""
        self.db_config = db_config
        self.redis_config = redis_config
        self.batch_size = batch_size
        
        self.redis_client = redis.Redis(**redis_config)
        self.queue_name = "download_queue"
        self.exit_flag_key = "download_exit_flag"
    
    def _get_db_connection(self):
        """获取数据库连接"""
        return pymysql.connect(**self.db_config)
    
    def distribute_tasks(self):
        """从数据库获取待下载任务并分发到队列"""
        # 检查退出标志
        if self.redis_client.get(self.exit_flag_key):
            logger.info("检测到退出标志,任务分发器准备退出...")
            return False
        
        with self._get_db_connection() as conn:
            with conn.cursor(pymysql.cursors.DictCursor) as cursor:
                # 获取待下载的文件
                cursor.execute(
                    '''SELECT id, url FROM files 
                       WHERE status = 0 
                       LIMIT %s''', 
                    (self.batch_size,)
                )
                pending_files = cursor.fetchall()
                
                if pending_files:
                    # 将任务添加到队列
                    for file in pending_files:
                        self.redis_client.rpush(self.queue_name, json.dumps(file))
                    
                    # 标记这些文件为"处理中"状态
                    file_ids = [file['id'] for file in pending_files]
                    placeholders = ', '.join(['%s'] * len(file_ids))
                    
                    cursor.execute(
                        f'''UPDATE files 
                            SET status = 2, updated_at = CURRENT_TIMESTAMP  -- 2表示处理中
                            WHERE id IN ({placeholders})''',
                        file_ids
                    )
                    
                    conn.commit()
                    logger.info(f"已分发 {len(pending_files)} 个任务到下载队列")
                    return len(pending_files)
                
                logger.info("没有待下载的任务")
                return 0
    
    def run(self, interval=30):
        """运行任务分发器,定期检查并分发任务"""
        logger.info(f"任务分发器启动,检查间隔: {interval}秒")
        
        try:
            while not self.redis_client.get(self.exit_flag_key):
                count = self.distribute_tasks()
                if count < self.batch_size:  # 如果任务不足一批,等待更长时间
                    time.sleep(interval)
                else:
                    time.sleep(1)  # 任务充足时,快速检查
            
            logger.info("任务分发器已退出")
            
        except KeyboardInterrupt:
            logger.info("用户中断,任务分发器退出")
        except Exception as e:
            logger.critical(f"任务分发器异常: {str(e)}", exc_info=True)
    
    def set_exit_flag(self):
        """设置退出标志"""
        self.redis_client.set(self.exit_flag_key, '1')
        logger.info("已设置退出标志")

# 使用示例
if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser(description='文件下载器')
    parser.add_argument('--role', choices=['worker', 'distributor'], default='worker', help='运行角色')
    parser.add_argument('--mode', choices=['forever', 'once', 'batch'], default='forever', help='工作模式')
    parser.add_argument('--config', default='config.json', help='配置文件路径')
    
    args = parser.parse_args()
    
    # 配置数据库连接
    db_config = {
        'host': 'localhost',
        'user': 'your_username',
        'password': 'your_password',
        'database': 'your_database',
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.DictCursor
    }
    
    # 配置Redis连接
    redis_config = {
        'host': 'localhost',
        'port': 6379,
        'db': 0,
        'decode_responses': False  # 存储二进制数据
    }
    
    # 根据角色启动相应组件
    if args.role == 'distributor':
        # 启动任务分发器
        distributor = TaskDistributor(db_config, redis_config, batch_size=1000)
        distributor.run(interval=30)
    else:
        # 启动工作节点
        worker = DownloadWorker(db_config, redis_config, download_dir='downloads')
        worker.run(mode=args.mode)

方案优势

  1. 水平扩展
    • 可以轻松通过启动更多工作节点增加并发处理能力

    • 工作节点可以分布在不同的服务器上

  2. 任务隔离
    • 每个工作节点是独立的进程,一个节点崩溃不会影响其他节点

    • 更好地利用多核 CPU 资源

  3. 可靠性
    • Redis 队列确保任务不会丢失

    • 支持断点续传

    • 任务处理状态持久化到数据库

  4. 监控和统计
    • 记录每个任务的处理结果

    • 可以通过 Redis 查看实时处理状态

主要改进

  1. 持久运行与优雅退出
    • 添加了基于 Redis 的退出标志机制

    • 支持多种运行模式:forever(持续运行)、once(处理一个任务)、batch(处理一批任务)

    • 工作节点和分发器都会定期检查退出标志

  2. 命令行参数支持
    • 使用 argparse 解析命令行参数

    • 可以指定运行角色(worker/distributor)和工作模式

  3. 增强的错误处理
    • 捕获 KeyboardInterrupt 异常,支持 Ctrl+C 中断

    • 更详细的日志记录

    使用方法

    1. 持续运行 Worker(默认模式)
      bash
      Python downloader.py --role worker --mode forever


    2. 单次运行 Worker(测试用)
      bash
      python downloader.py --role worker --mode once


    3. 批量运行 Worker(处理完队列中所有任务后退出)
      bash
      python downloader.py --role worker --mode batch


    4. 启动任务分发器
      bash
      python downloader.py --role distributor


    5. 设置退出标志(优雅停止所有组件)
      python
      运行
      # 在交互式Python环境中执行import redis
      r = redis.Redis()r.set('download_exit_flag', '1')



    这个改进后的方案让 Worker 可以持续运行,同时提供了灵活的控制方式,满足你的需求。


相关阅读

添加新评论