#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# Copyright 2019 黎慧剑
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""
文件传输协议模块
@module protocol
@file protocol.py
"""
import os
import sys
import threading
from typing import Iterator, Union
from io import FileIO
from HiveNetCore.utils.net_tool import NetTool
# 根据当前文件路径将包路径纳入, 在非安装的情况下可以引用到
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from HiveNetFileTransfer.saver import TransferSaver
__MOUDLE__ = 'protocol' # 模块名
__DESCRIPT__ = u'文件传输协议模块' # 模块描述
__VERSION__ = '0.1.0' # 版本
__AUTHOR__ = u'黎慧剑' # 作者
__PUBLISH__ = '2021.08.24' # 发布日期
[文档]class ProtocolFw(object):
"""
文件传输协议框架
"""
#############################
# 构造函数
#############################
[文档] def __init__(self, src_file: str, dest_file: str, is_resume: bool = True, is_overwrite: bool = False,
thread_num: int = 1, block_size: int = 4096, cache_size: int = 1024, auto_expand: bool = True,
**kwargs):
"""
初始化文件传输协议类
注: 实现类必须在初始化中完成目标端 TransferSaver 的初始化处理
@param {str} src_file - 源文件信息
@param {str} dest_file - 目标文件信息
@param {bool} is_resume=True - 指定是否续传(自动查找已下载的信息), 如果不指定续传将自动删除原来已下载临时文件
@param {bool} is_overwrite=False - 是否覆盖已有文件, 如果为否, 则目标文件已存在的情况下抛出异常
@param {int} thread_num=1 - 写入处理线程数量
@param {int} block_size=4096 - 每次传输块大小, 单位为byte
@param {int} cache_size=1024 - 单线程缓存大小, 单位为kb(注意: 真实缓存大小还需要乘以处理线程数量)
@param {bool} auto_expand=True - 是否自动扩展文件大小(否则在初始化时会自动创建指定大小的文件)
@param {kwargs} - 扩展参数, 重载类自行扩展处理所需的参数
"""
raise NotImplementedError()
#############################
# with 方法支持
#############################
def __enter__(self):
"""
with方法进入的处理
"""
return self
def __exit__(self, type, value, trace):
"""
with方法退出函数
@param {object} type - 执行异常的异常类型
@param {object} value - 执行异常的异常对象值
@param {object}} trace - 执行异常的异常trace对象
"""
# 关闭资源
self.close()
#############################
# 需支持的属性(需继承类实现)
#############################
@property
def file_size(self) -> int:
"""
获取传输文件大小
@property {int}
"""
raise NotImplementedError()
#############################
# 工具函数
#############################
[文档] def pause(self):
"""
通知协议暂停传输, 保存当前状态
"""
raise NotImplementedError()
[文档] def close(self):
"""
关闭传输协议对象
"""
raise NotImplementedError()
#############################
# 文件读取的工具函数(需继承类实现)
#############################
[文档] def open_file(self, index: int = 0):
"""
打开文件并返回文件对象
@param {int} index=0 - 当前的数据处理线程索引
@returns {dict} - 打开的文件对象属性字典
{'handle': FileIO, 'close_able': 是否可关闭, lock: 锁对象}
"""
raise NotImplementedError()
[文档] def close_file(self, index: int, is_force: bool = False):
"""
关闭打开的文件
@param {int} index - 要关闭的文件对象对应的处理线程索引
@param {bool} is_force=False - 指示是否强制关闭
"""
raise NotImplementedError()
[文档] def read_file_data(self, index: int, handle, start: int, size: int,
lock: threading.RLock) -> bytes:
"""
获取文件指定位置数据
@param {int} index - 处理读取的线程索引
@param {object} handle - 打开的文件句柄
@param {int} start - 要获取的数据开始位置
@param {int} size - 要获取的数据大小
@param {threading.RLock} lock - 读取数据的锁对象
@returns {bytes} - 获取到的数据字典
注: 如果开始位置超过文件大小, 将返回b''; 如果要获取的数据大小超过文件, 则返回真实的数据大小
对于无法预知文件大小的情况, 如果返回b''也代表着文件结束
"""
raise NotImplementedError()
#############################
# 写入对象的工具函数(需继承类实现)
#############################
[文档] def open_writer(self, index: int = 0) -> dict:
"""
打开写入对象并返回对象属性
@param {int} index=0 - 当前的数据处理线程索引
@returns {dict} - 打开的文件对象属性字典
{'handle': 写入对象, 'close_able': 是否可关闭, lock: 锁对象}
"""
raise NotImplementedError()
[文档] def close_writer(self, index: int, is_force: bool = False):
"""
关闭打开的写入对象
@param {int} index - 要关闭的写入对象对应的处理线程索引
@param {bool} is_force=False - 指示是否强制关闭
"""
raise NotImplementedError()
[文档] def write_data(self, handle, lock: threading.RLock, index: int = 0, start: int = None,
size: int = None, data: bytes = None) -> dict:
"""
写入文件数据
@param {object} handle - 写入对象
@param {threading.RLock} lock - 锁定写入操作的锁对象
@param {int} index=0 - 指定写入数据的线程索引
@param {int} start=None - 数据在文件的开始位置, 如果传空代表请求该线程索引对应的获取任务信息
@param {int} size=None - 传入数据的长度
@param {bytes} data=None - 传入数据字节数组
@returns {dict} - 返回下一个任务要获取的信息字典, 格式为:
{
'status': 0, # 状态, 0-成功, 1-开始位置与线程缓存不一致, 2-全部下载完成, 3-文件md5校验失败
'index': 0, # 当前线程索引
'start': -1, # 开始位置, 如果传入-1代表该线程已无获取任务
'size': 0, # 要获取数据的大小
}
"""
raise NotImplementedError()
[文档] def file_finished(self):
"""
通知数据保存对象文件已结束
"""
raise NotImplementedError()
[文档] def flush_cache(self):
"""
强制将缓存数据写入实际文件
"""
raise NotImplementedError()
[文档] def get_thread_num(self) -> int:
"""
获取支持处理的线程数
@returns {int} - 线程数
"""
raise NotImplementedError()
[文档] def get_saver_info(self) -> dict:
"""
获取数据保存信息
@returns {dict} - 已保存的信息字典
{
'file_size': -1, # 要接收的文件大小, -1 代表不确定文件实际大小
'write_size': 0, # 已写入的数据大小
'md5': '', # 文件的md5值
}
"""
raise NotImplementedError()
[文档] def get_extend_info(self) -> dict:
"""
获取保存的扩展信息字典
@returns {dict} - 返回保存的扩展信息字典
"""
raise NotImplementedError()
[文档]class LocalProtocol(ProtocolFw):
"""
本地文件传输至本地的传输协议(复制)
"""
#############################
# 构造函数
#############################
[文档] def __init__(self, src_file: Union[str, FileIO], dest_file: str, is_resume: bool = True, is_overwrite: bool = False,
thread_num: int = 1, block_size: int = 4096, cache_size: int = 1024, auto_expand: bool = True,
**kwargs):
"""
初始化文件传输协议类
注: 实现类必须在初始化中完成目标端 TransferSaver 的初始化处理
@param {str|FileIO} src_file - 源文件路径或已打开的文件句柄
@param {str} dest_file - 目标文件路径
@param {bool} is_resume=True - 指定是否续传(自动查找已下载的信息), 如果不指定续传将自动删除原来已下载临时文件
@param {bool} is_overwrite=False - 是否覆盖已有文件, 如果为否, 则目标文件已存在的情况下抛出异常
@param {int} thread_num=1 - 写入处理线程数量
@param {int} block_size=4096 - 每次传输块大小, 单位为byte
@param {int} cache_size=1024 - 单线程缓存大小, 单位为kb(注意: 真实缓存大小还需要乘以处理线程数量)
@param {bool} auto_expand=True - 是否自动扩展文件大小(否则在初始化时会自动创建指定大小的文件)
@param {kwargs} - 扩展参数, 重载类自行扩展处理所需的参数
"""
# 要保存的参数
self.src_file = src_file
self.dest_file = dest_file
self.is_resume = is_resume
self.is_overwrite = is_overwrite
self.thread_num = thread_num
self.block_size = block_size
self.cache_size = cache_size
self.auto_expand = auto_expand
self.kwargs = kwargs
# 文件访问句柄字典, 供open_file处理使用, key为处理线程索引, value为{'handle': ..., 'close_able':..., lock: ...}
self._file_handles = dict()
self._file_handles_lock = threading.RLock() # 控制打开关闭文件的锁
self._mutiple_read = False # 控制是否允许多线程读的变量
# 写入对象字典, 供open_writer处理使用, key为处理线程索引, value为{'handle': ..., 'close_able':..., lock: ...}
self._writer_handles = dict()
self._writer_handles_lock = threading.RLock() # 控制打开关闭文件的锁
self._mutiple_write = False # 控制是否允许多线程写的变量
# 初始化数据接收对象
self.init_saver()
#############################
# 需支持的属性(需继承类实现)
#############################
@property
def file_size(self) -> int:
"""
获取传输文件大小
@property {int}
"""
return self._file_size
#############################
# 工具函数
#############################
[文档] def pause(self):
"""
通知协议暂停传输, 保存当前状态
"""
self.flush_cache()
[文档] def close(self):
"""
关闭传输协议对象
"""
# 销毁数据接收对象
self.destroy_saver()
# 关闭写入对象
_keys = list(self._writer_handles.keys())
for _index in _keys:
self.close_writer(_index, is_force=True)
# 关闭文件
if type(self.src_file) == str:
_keys = list(self._file_handles.keys())
for _index in _keys:
self.close_file(_index, is_force=True)
#############################
# 文件读取的工具函数
#############################
[文档] def get_file_size(self) -> int:
"""
获取文件的大小
@returns {int} - 文件大小, 如果不支持获取文件大小返回 None
"""
if type(self.src_file) == str:
return os.path.getsize(self.src_file)
else:
# 移动指针到文件结尾, 指针位置就是文件大小
return self.src_file.seek(0, 2)
[文档] def get_file_md5(self) -> str:
"""
获取文件的md5值
@returns {str} - 文件md5值, 如果获取不到md5值返回None
"""
return NetTool.get_file_md5(self.src_file)
#############################
# 文件读取的工具函数(需继承类实现)
#############################
[文档] def open_file(self, index: int = 0):
"""
打开文件并返回文件对象
@param {int} index=0 - 当前的数据处理线程索引
@returns {dict} - 打开的文件对象属性字典
{'handle': FileIO, 'close_able': 是否可关闭, lock: 锁对象}
"""
self._mutiple_read = False
if type(self.src_file) != str:
# FileIO 的对象不允许多线程读
self._mutiple_read = False
self._file_handles_lock.acquire()
try:
# 获取文件对象属性字典
if self._mutiple_read:
# 允许多线程读访问
_file_dict = self._file_handles.get(index, None)
if _file_dict is None:
# 获取不到, 创建新文件访问对象
_lock = threading.RLock()
_file_dict = {
'handle': open(self.src_file, 'rb'), 'close_able': True, 'lock': _lock
}
self._file_handles[index] = _file_dict
else:
# 单线程读写模式, 只允许访问第0个, 并且文件不允许关闭
_file_dict = self._file_handles.get(0, None)
if _file_dict is None:
_lock = threading.RLock()
if type(self.src_file) == str:
# 文件路径
_file_dict = {
'handle': open(self.src_file, 'rb'), 'close_able': False, 'lock': _lock
}
else:
_file_dict = {
'handle': self.src_file, 'close_able': False, 'lock': _lock
}
self._file_handles[0] = _file_dict
finally:
self._file_handles_lock.release()
# 返回结果
return _file_dict
[文档] def close_file(self, index: int, is_force: bool = False):
"""
关闭打开的文件
@param {int} index - 要关闭的文件对象对应的处理线程索引
@param {bool} is_force=False - 指示是否强制关闭
"""
self._file_handles_lock.acquire()
try:
_file_dict = self._file_handles.get(index, None)
if _file_dict is not None and (is_force or _file_dict['close_able']):
# 允许关闭或强制关闭
self._file_handles.pop(index, None)
_file_dict['handle'].close()
finally:
self._file_handles_lock.release()
[文档] def read_file_data(self, index: int, handle: FileIO, start: int, size: int,
lock: threading.RLock) -> bytes:
"""
获取文件指定位置数据
@param {int} index - 处理读取的线程索引
@param {object} handle - 打开的文件句柄
@param {int} start - 要获取的数据开始位置
@param {int} size - 要获取的数据大小
@param {threading.RLock} lock - 读取数据的锁对象
@returns {bytes} - 获取到的数据字典
注: 如果开始位置超过文件大小, 将返回b''; 如果要获取的数据大小超过文件, 则返回真实的数据大小
对于无法预知文件大小的情况, 如果返回b''也代表着文件结束
"""
lock.acquire()
try:
# 移动到指定位置并获取数据
handle.seek(start)
_bytes = handle.read(size)
return _bytes
finally:
lock.release()
#############################
# 写入对象的工具函数
#############################
[文档] def init_saver(self):
"""
初始化数据保存对象
"""
# 处理源文件信息
self._file_size = self.get_file_size()
self._file_md5 = self.get_file_md5()
# 处理文件传输接收对象
self._saver = TransferSaver(
self.dest_file, is_resume=self.is_resume, file_size=self._file_size, md5=self._file_md5,
is_overwrite=self.is_overwrite, thread_num=self.thread_num, block_size=self.block_size,
cache_size=self.cache_size, auto_expand=self.auto_expand
)
self.thread_num = self._saver._thread_num # 线程数有可能被改变
[文档] def destroy_saver(self):
"""
销毁接收数据对象
"""
# 写入缓存并删除对象
self._saver.close()
#############################
# 写入对象的工具函数(需继承类实现)
#############################
[文档] def open_writer(self, index: int = 0) -> dict:
"""
打开写入对象并返回对象属性
@param {int} index=0 - 当前的数据处理线程索引
@returns {dict} - 打开的文件对象属性字典
{'handle': 写入对象, 'close_able': 是否可关闭, lock: 锁对象}
"""
# 本地文件无需特别处理, 直接返回None就好
return {'handle': None, 'close_able': False, 'lock': None}
[文档] def close_writer(self, index: int, is_force: bool = False):
"""
关闭打开的写入对象
@param {int} index - 要关闭的写入对象对应的处理线程索引
@param {bool} is_force=False - 指示是否强制关闭
"""
# 本地文件无需特别处理
pass
[文档] def write_data(self, handle, lock: threading.RLock, index: int = 0, start: int = None,
size: int = None, data: bytes = None) -> dict:
"""
写入文件
@param {object} handle - 写入对象
@param {threading.RLock} lock - 锁定写入操作的锁对象
@param {int} index=0 - 指定写入数据的线程索引
@param {int} start=None - 数据在文件的开始位置, 如果传空代表请求该线程索引对应的获取任务信息
@param {int} size=None - 传入数据的长度
@param {bytes} data=None - 传入数据字节数组
@returns {dict} - 返回下一个任务要获取的信息字典, 格式为:
{
'status': 0, # 状态, 0-成功, 1-开始位置与线程缓存不一致, 2-全部下载完成, 3-文件md5校验失败
'index': 0, # 当前线程索引
'start': -1, # 开始位置, 如果传入-1代表该线程已无获取任务
'size': 0, # 要获取数据的大小
}
"""
# 本地文件模式, 无需使用handle和lock参数
return self._saver.write_data(
index=index, start=start, size=size, data=data
)
[文档] def file_finished(self):
"""
通知数据保存对象文件已结束
"""
self._saver.finished()
[文档] def flush_cache(self):
"""
强制将缓存数据写入实际文件
"""
self._saver.flush()
[文档] def get_thread_num(self) -> int:
"""
获取支持处理的线程数
@returns {int} - 线程数
"""
return self.thread_num
[文档] def get_saver_info(self) -> dict:
"""
获取数据保存信息
@returns {dict} - 已保存的信息字典
{
'file_size': -1, # 要接收的文件大小, -1 代表不确定文件实际大小
'write_size': 0, # 已写入的数据大小
'md5': '', # 文件的md5值
}
"""
return self._saver.get_save_info()
[文档] def get_extend_info(self) -> dict:
"""
获取保存的扩展信息字典
@returns {dict} - 返回保存的扩展信息字典
"""
return self._saver.get_extend_info()
if __name__ == '__main__':
# 当程序自己独立运行时执行的操作
# 打印版本信息
print(('模块名: %s - %s\n'
'作者: %s\n'
'发布日期: %s\n'
'版本: %s' % (__MOUDLE__, __DESCRIPT__, __AUTHOR__, __PUBLISH__, __VERSION__)))