Source code for oss2.resumable

# -*- coding: utf-8 -*-

"""
oss2.resumable
~~~~~~~~~~~~~~

The module contains the classes for resumable upload.
"""

import os

from . import utils
from . import iterators
from . import exceptions
from . import defaults
from .api import Bucket

from .models import PartInfo
from .compat import json, stringify, to_unicode
from .task_queue import TaskQueue
from .defaults import get_logger

import functools
import threading
import random
import string

import shutil

_MAX_PART_COUNT = 10000
_MIN_PART_SIZE = 100 * 1024


def resumable_upload(bucket, key, filename,
                     store=None,
                     headers=None,
                     multipart_threshold=None,
                     part_size=None,
                     progress_callback=None,
                     num_threads=None):
    """Uses multipart upload to upload a local file.
    
    By default, `oss2.defaults.multipart_num_threads` parallel threads are used. The upload progress is saved
    to a checkpoint file on the local disk, under the HOME folder by default. If the upload is interrupted and
    a new upload is started with the same local file and destination, the upload resumes from where the last
    one stopped, based on that checkpoint file.
    
    .. Note::
        #. If CryptoBucket is used, the function degrades to a normal upload operation. 
    
    :param bucket: :class:`Bucket <oss2.Bucket>` instance.
    :param key: The OSS object key.
    :param filename: The name of the local file to be uploaded.
    :param store: Upload progress information storage. ResumableStore is used if not specified. See `ResumableStore` for more information.
    :param headers: HTTP headers for `put_object` or `init_multipart_upload`.
    :param multipart_threshold: Files exceeding this size will be uploaded with multipart upload.
    :param part_size: Part size. The value is calculated automatically if not specified.
    :param progress_callback: The progress callback. See :ref:`progress_callback` for more information.
    :param num_threads: Upload parallel thread count. `oss2.defaults.multipart_num_threads` will be used if not specified.
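
    Usage sketch (the credentials, endpoint, bucket name and file names below are placeholders)::

        import oss2

        auth = oss2.Auth('<your-access-key-id>', '<your-access-key-secret>')
        bucket = oss2.Bucket(auth, 'https://oss-cn-hangzhou.aliyuncs.com', '<your-bucket>')

        oss2.resumable_upload(bucket, '<your-object-key>', '<local-file>',
                              multipart_threshold=200 * 1024,
                              part_size=100 * 1024,
                              num_threads=4)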
    """
    size = os.path.getsize(filename)
    multipart_threshold = defaults.get(multipart_threshold, defaults.multipart_threshold)

    if isinstance(bucket, Bucket) and size >= multipart_threshold:
        uploader = _ResumableUploader(bucket, key, filename, size, store,
                                      part_size=part_size,
                                      headers=headers,
                                      progress_callback=progress_callback,
                                      num_threads=num_threads)
        result = uploader.upload()
    else:
        with open(to_unicode(filename), 'rb') as f:
            result = bucket.put_object(key, f,
                              headers=headers,
                              progress_callback=progress_callback)
    
    return result

def resumable_download(bucket, key, filename,
                       multiget_threshold=None,
                       part_size=None,
                       progress_callback=None,
                       num_threads=None,
                       store=None):
    """Resumable download.

    The implementation:
        #. Create a temporary file named after the target file plus a random suffix.
        #. Download the OSS object range by range (with the `Range` header) into the temporary file.
        #. Once the download is finished, rename the temporary file to the target file name.

    During a download, the checkpoint information (the finished ranges) is stored on disk as a checkpoint file.

    If the download is interrupted, it can resume from the checkpoint file as long as the source and target files still match; only the missing parts are downloaded.

    By default, the checkpoint file is stored in a subfolder of the HOME folder. A different folder can be specified through the `store` parameter.

    .. Note::
        #. For the same source and target file, at any given time, there should be only one running instance of this API. Otherwise multiple calls could lead to checkpoint files overwriting each other.
        #. Don't use a small part size. The suggested size is no less than `oss2.defaults.multiget_part_size`.
        #. The API will overwrite the target file if it exists already.
        #. If CryptoBucket is used, the function will become a normal download.

    :param bucket: :class:`Bucket <oss2.Bucket>` instance.
    :param str key: The OSS object key.
    :param str filename: Local file name.
    :param int multiget_threshold: The threshold of the file size to use multiget download.
    :param int part_size: The preferred part size. The actual part size might be slightly different according to determine_part_size().
    :param progress_callback: Progress callback. See :ref:`progress_callback` for more information.
    :param num_threads: Parallel thread number. The default value is `oss2.defaults.multiget_num_threads`.

    :param store: Specifies the persistent storage for checkpoint information. For example, the folder of the checkpoint file.
    :type store: :class:`ResumableDownloadStore <oss2.models.ResumableDownloadStore>`

    :raises: :class:`NotFound <oss2.exceptions.NotFound>` if the source OSS object does not exist; other exceptions may be raised for other issues.
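
    Usage sketch (the credentials, endpoint, bucket name and file names below are placeholders)::

        import oss2

        auth = oss2.Auth('<your-access-key-id>', '<your-access-key-secret>')
        bucket = oss2.Bucket(auth, 'https://oss-cn-hangzhou.aliyuncs.com', '<your-bucket>')

        oss2.resumable_download(bucket, '<your-object-key>', '<local-file>',
                                multiget_threshold=200 * 1024,
                                part_size=100 * 1024,
                                num_threads=4)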
    """

    multiget_threshold = defaults.get(multiget_threshold, defaults.multiget_threshold)

    if isinstance(bucket, Bucket):
        result = bucket.head_object(key)
        if result.content_length >= multiget_threshold:
            downloader = _ResumableDownloader(bucket, key, filename, _ObjectInfo.make(result),
                                              part_size=part_size,
                                              progress_callback=progress_callback,
                                              num_threads=num_threads,
                                              store=store)
            downloader.download()
        else:
            bucket.get_object_to_file(key, filename,
                                  progress_callback=progress_callback)
    else:
        bucket.get_object_to_file(key, filename,
                                  progress_callback=progress_callback)


_MAX_MULTIGET_PART_COUNT = 100


def determine_part_size(total_size,
                        preferred_size=None):
    """Determine the part size of the multiparts upload.

    :param int total_size: Total size to upload.
    :param int preferred_size: The user's preferred part size. Defaults to `defaults.part_size`.

    :return: Part size.
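
    A worked example (the sizes are chosen only for illustration)::

        import oss2

        total_size = 3 * 1024 * 1024 * 1024   # a 3 GB upload
        # 100 KB * 10000 parts covers only about 1 GB, so the preferred size
        # is enlarged until the file fits into at most 10000 parts:
        # total_size // 10000 rounded up, i.e. 322123 bytes here.
        part_size = oss2.determine_part_size(total_size, preferred_size=100 * 1024)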
    """
    if not preferred_size:
        preferred_size = defaults.part_size

    return _determine_part_size_internal(total_size, preferred_size, _MAX_PART_COUNT)


def _determine_part_size_internal(total_size, preferred_size, max_count):
    if total_size < preferred_size:
        return total_size

    if preferred_size * max_count < total_size:
        if total_size % max_count:
            return total_size // max_count + 1
        else:
            return total_size // max_count
    else:
        return preferred_size


def _split_to_parts(total_size, part_size):
    parts = []
    num_parts = utils.how_many(total_size, part_size)

    for i in range(num_parts):
        if i == num_parts - 1:
            start = i * part_size
            end = total_size
        else:
            start = i * part_size
            end = part_size + start

        parts.append(_PartToProcess(i + 1, start, end))

    return parts


class _ResumableOperation(object):
    def __init__(self, bucket, key, filename, size, store,
                 progress_callback=None):
        self.bucket = bucket
        self.key = key
        self.filename = filename
        self.size = size

        self._abspath = os.path.abspath(filename)

        self.__store = store
        self.__record_key = self.__store.make_store_key(bucket.bucket_name, key, self._abspath)
        get_logger().info('key is {0}'.format(self.__record_key))

        # protect self.__progress_callback
        self.__plock = threading.Lock()
        self.__progress_callback = progress_callback

    def _del_record(self):
        self.__store.delete(self.__record_key)

    def _put_record(self, record):
        self.__store.put(self.__record_key, record)

    def _get_record(self):
        return self.__store.get(self.__record_key)

    def _report_progress(self, consumed_size):
        if self.__progress_callback:
            with self.__plock:
                self.__progress_callback(consumed_size, self.size)


class _ObjectInfo(object):
    def __init__(self):
        self.size = None
        self.etag = None
        self.mtime = None

    @staticmethod
    def make(head_object_result):
        objectInfo = _ObjectInfo()
        objectInfo.size = head_object_result.content_length
        objectInfo.etag = head_object_result.etag
        objectInfo.mtime = head_object_result.last_modified

        return objectInfo


class _ResumableDownloader(_ResumableOperation):
    def __init__(self, bucket, key, filename, objectInfo,
                 part_size=None,
                 store=None,
                 progress_callback=None,
                 num_threads=None):
        super(_ResumableDownloader, self).__init__(bucket, key, filename, objectInfo.size,
                                                   store or ResumableDownloadStore(),
                                                   progress_callback=progress_callback)
        self.objectInfo = objectInfo

        self.__part_size = defaults.get(part_size, defaults.multiget_part_size)
        self.__part_size = _determine_part_size_internal(self.size, self.__part_size, _MAX_MULTIGET_PART_COUNT)

        self.__tmp_file = None
        self.__num_threads = defaults.get(num_threads, defaults.multiget_num_threads)
        self.__finished_parts = None
        self.__finished_size = None

        # protect record
        self.__lock = threading.Lock()
        self.__record = None

    def download(self):
        self.__load_record()

        parts_to_download = self.__get_parts_to_download()

        # create the tmp file if it does not exist
        open(self.__tmp_file, 'a').close()

        q = TaskQueue(functools.partial(self.__producer, parts_to_download=parts_to_download),
                      [self.__consumer] * self.__num_threads)
        q.run()

        utils.force_rename(self.__tmp_file, self.filename)

        self._report_progress(self.size)
        self._del_record()

    def __producer(self, q, parts_to_download=None):
        for part in parts_to_download:
            q.put(part)

    def __consumer(self, q):
        while q.ok():
            part = q.get()
            if part is None:
                break

            self.__download_part(part)

    def __download_part(self, part):
        self._report_progress(self.__finished_size)

        with open(self.__tmp_file, 'rb+') as f:
            f.seek(part.start, os.SEEK_SET)

            headers = {'If-Match': self.objectInfo.etag,
                       'If-Unmodified-Since': utils.http_date(self.objectInfo.mtime)}
            result = self.bucket.get_object(self.key, byte_range=(part.start, part.end - 1), headers=headers)
            utils.copyfileobj_and_verify(result, f, part.end - part.start, request_id=result.request_id)

        self.__finish_part(part)

    def __load_record(self):
        record = self._get_record()

        if record and not self.is_record_sane(record):
            self._del_record()
            record = None

        if record and not os.path.exists(self.filename + record['tmp_suffix']):
            self._del_record()
            record = None

        if record and self.__is_remote_changed(record):
            utils.silently_remove(self.filename + record['tmp_suffix'])
            self._del_record()
            record = None

        if not record:
            record = {'mtime': self.objectInfo.mtime, 'etag': self.objectInfo.etag, 'size': self.objectInfo.size,
                      'bucket': self.bucket.bucket_name, 'key': self.key, 'part_size': self.__part_size,
                      'tmp_suffix': self.__gen_tmp_suffix(), 'abspath': self._abspath,
                      'parts': []}
            self._put_record(record)

        self.__tmp_file = self.filename + record['tmp_suffix']
        self.__part_size = record['part_size']
        self.__finished_parts = list(_PartToProcess(p['part_number'], p['start'], p['end']) for p in record['parts'])
        self.__finished_size = sum(p.size for p in self.__finished_parts)
        self.__record = record

    def __get_parts_to_download(self):
        assert self.__record

        all_set = set(_split_to_parts(self.size, self.__part_size))
        finished_set = set(self.__finished_parts)

        return sorted(list(all_set - finished_set), key=lambda p: p.part_number)

    @staticmethod
    def is_record_sane(record):
        try:
            for key in ('etag', 'tmp_suffix', 'abspath', 'bucket', 'key'):
                if not isinstance(record[key], str):
                    get_logger().info('{0} is not a string: {1}, but {2}'.format(key, record[key], record[key].__class__))
                    return False

            for key in ('part_size', 'size', 'mtime'):
                if not isinstance(record[key], int):
                    get_logger().info('{0} is not an integer: {1}, but {2}'.format(key, record[key], record[key].__class__))
                    return False

            for key in ('parts',):
                if not isinstance(record[key], list):
                    get_logger().info('{0} is not a list: {1}, but {2}'.format(key, record[key], record[key].__class__))
                    return False
        except KeyError as e:
            get_logger().info('Key not found: {0}'.format(e.args))
            return False

        return True

    def __is_remote_changed(self, record):
        return (record['mtime'] != self.objectInfo.mtime or
            record['size'] != self.objectInfo.size or
            record['etag'] != self.objectInfo.etag)

    def __finish_part(self, part):
        get_logger().debug('finishing part: part_number={0}, start={1}, end={2}'.format(part.part_number, part.start, part.end))

        with self.__lock:
            self.__finished_parts.append(part)
            self.__finished_size += part.size

            self.__record['parts'].append({'part_number': part.part_number,
                                           'start': part.start,
                                           'end': part.end})
            self._put_record(self.__record)

    def __gen_tmp_suffix(self):
        return '.tmp-' + ''.join(random.choice(string.ascii_lowercase) for i in range(12))


class _ResumableUploader(_ResumableOperation):
    """Resumable upload.

    :param bucket: :class:`Bucket <oss2.Bucket>` instance
    :param key: OSS object key.
    :param filename: The file name to upload.
    :param size: Total file size.
    :param store: The store for persisting checkpoint information.
    :param headers: The HTTP headers for `init_multipart_upload`.
    :param part_size: Part size. If specified, it takes priority over the calculated part size. If not specified, a resumed upload reuses the part size recorded by the original upload.

    :param progress_callback: Progress callback. Check out :ref:`progress_callback`.
    """
    def __init__(self, bucket, key, filename, size,
                 store=None,
                 headers=None,
                 part_size=None,
                 progress_callback=None,
                 num_threads=None):
        super(_ResumableUploader, self).__init__(bucket, key, filename, size,
                                                 store or ResumableStore(),
                                                 progress_callback=progress_callback)

        self.__headers = headers
        self.__part_size = defaults.get(part_size, defaults.part_size)

        self.__mtime = os.path.getmtime(filename)

        self.__num_threads = defaults.get(num_threads, defaults.multipart_num_threads)

        self.__upload_id = None

        # protect below fields
        self.__lock = threading.Lock()
        self.__record = None
        self.__finished_size = 0
        self.__finished_parts = None

    def upload(self):
        self.__load_record()

        parts_to_upload = self.__get_parts_to_upload(self.__finished_parts)
        parts_to_upload = sorted(parts_to_upload, key=lambda p: p.part_number)

        q = TaskQueue(functools.partial(self.__producer, parts_to_upload=parts_to_upload),
                      [self.__consumer] * self.__num_threads)
        q.run()

        self._report_progress(self.size)

        result = self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__finished_parts)
        self._del_record()
        
        return result

    def __producer(self, q, parts_to_upload=None):
        for part in parts_to_upload:
            q.put(part)

    def __consumer(self, q):
        while True:
            part = q.get()
            if part is None:
                break

            self.__upload_part(part)

    def __upload_part(self, part):
        with open(to_unicode(self.filename), 'rb') as f:
            self._report_progress(self.__finished_size)

            f.seek(part.start, os.SEEK_SET)
            result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
                                             utils.SizedFileAdapter(f, part.size))

            self.__finish_part(PartInfo(part.part_number, result.etag, size=part.size))

    def __finish_part(self, part_info):
        with self.__lock:
            self.__finished_parts.append(part_info)
            self.__finished_size += part_info.size

            self.__record['parts'].append({'part_number': part_info.part_number, 'etag': part_info.etag})
            self._put_record(self.__record)

    def __load_record(self):
        record = self._get_record()

        if record and not _is_record_sane(record):
            self._del_record()
            record = None

        if record and self.__file_changed(record):
            get_logger().debug('{0} was changed, clear the record.'.format(self.filename))
            self._del_record()
            record = None

        if record and not self.__upload_exists(record['upload_id']):
            get_logger().debug('{0} upload not exist, clear the record.'.format(record['upload_id']))
            self._del_record()
            record = None

        if not record:
            part_size = determine_part_size(self.size, self.__part_size)
            upload_id = self.bucket.init_multipart_upload(self.key, headers=self.__headers).upload_id
            record = {'upload_id': upload_id, 'mtime': self.__mtime, 'size': self.size, 'parts': [],
                      'abspath': self._abspath, 'bucket': self.bucket.bucket_name, 'key': self.key,
                      'part_size': part_size}

            get_logger().debug('put new record upload_id={0} part_size={1}'.format(upload_id, part_size))
            self._put_record(record)

        self.__record = record
        self.__part_size = self.__record['part_size']
        self.__upload_id = self.__record['upload_id']
        self.__finished_parts = self.__get_finished_parts()
        self.__finished_size = sum(p.size for p in self.__finished_parts)

    def __get_finished_parts(self):
        last_part_number = utils.how_many(self.size, self.__part_size)

        parts = []

        for p in self.__record['parts']:
            part_info = PartInfo(int(p['part_number']), p['etag'])
            if part_info.part_number == last_part_number:
                part_info.size = self.size % self.__part_size
            else:
                part_info.size = self.__part_size

            parts.append(part_info)

        return parts

    def __upload_exists(self, upload_id):
        try:
            list(iterators.PartIterator(self.bucket, self.key, upload_id, '0', max_parts=1))
        except exceptions.NoSuchUpload:
            return False
        else:
            return True

    def __file_changed(self, record):
        return record['mtime'] != self.__mtime or record['size'] != self.size

    def __get_parts_to_upload(self, parts_uploaded):
        all_parts = _split_to_parts(self.size, self.__part_size)
        if not parts_uploaded:
            return all_parts

        all_parts_map = dict((p.part_number, p) for p in all_parts)

        for uploaded in parts_uploaded:
            if uploaded.part_number in all_parts_map:
                del all_parts_map[uploaded.part_number]

        return all_parts_map.values()


_UPLOAD_TEMP_DIR = '.py-oss-upload'
_DOWNLOAD_TEMP_DIR = '.py-oss-download'


class _ResumableStoreBase(object):
    def __init__(self, root, dir):
        self.dir = os.path.join(root, dir)

        if os.path.isdir(self.dir):
            return

        utils.makedir_p(self.dir)

    def get(self, key):
        pathname = self.__path(key)

        get_logger().debug('get key={0}, pathname={1}'.format(key, pathname))

        if not os.path.exists(pathname):
            return None

        # json.load() returns unicode. For Python2, it's converted to str.

        try:
            with open(to_unicode(pathname), 'r') as f:
                content = json.load(f)
        except ValueError:
            os.remove(pathname)
            return None
        else:
            return stringify(content)

    def put(self, key, value):
        pathname = self.__path(key)

        with open(to_unicode(pathname), 'w') as f:
            json.dump(value, f)

        get_logger().debug('put key={0}, pathname={1}'.format(key, pathname))

    def delete(self, key):
        pathname = self.__path(key)
        os.remove(pathname)

        get_logger().debug('del key={0}, pathname={1}'.format(key, pathname))

    def __path(self, key):
        return os.path.join(self.dir, key)


def _normalize_path(path):
    return os.path.normpath(os.path.normcase(path))


class ResumableStore(_ResumableStoreBase):
    """The class for persisting uploading checkpoint information.

    The checkpoint information is stored under the `root/dir/` folder.

    :param str root: Root folder, default is `HOME`.
    :param str dir: Subfolder, default is `_UPLOAD_TEMP_DIR`. 
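
    A minimal usage sketch with a custom checkpoint location (`bucket` is an existing
    :class:`Bucket <oss2.Bucket>` instance; the folder and file names are placeholders)::

        store = oss2.ResumableStore(root='/tmp', dir='.my-upload-checkpoints')
        oss2.resumable_upload(bucket, '<your-object-key>', '<local-file>', store=store)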
    """
    def __init__(self, root=None, dir=None):
        super(ResumableStore, self).__init__(root or os.path.expanduser('~'), dir or _UPLOAD_TEMP_DIR)

    @staticmethod
    def make_store_key(bucket_name, key, filename):
        filepath = _normalize_path(filename)

        oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
        return utils.md5_string(oss_pathname) + '-' + utils.md5_string(filepath)


class ResumableDownloadStore(_ResumableStoreBase):
    """The class for persisting downloading checkpoint information.

    The checkpoint information is saved under the `root/dir/` folder.

    :param str root: Root folder, default is `HOME`.
    :param str dir: Subfolder, default is `_DOWNLOAD_TEMP_DIR`.
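
    A minimal usage sketch with a custom checkpoint location (`bucket` is an existing
    :class:`Bucket <oss2.Bucket>` instance; the folder and file names are placeholders)::

        store = oss2.ResumableDownloadStore(root='/tmp', dir='.my-download-checkpoints')
        oss2.resumable_download(bucket, '<your-object-key>', '<local-file>', store=store)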
    """
    def __init__(self, root=None, dir=None):
        super(ResumableDownloadStore, self).__init__(root or os.path.expanduser('~'), dir or _DOWNLOAD_TEMP_DIR)

    @staticmethod
    def make_store_key(bucket_name, key, filename):
        filepath = _normalize_path(filename)

        oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
        return utils.md5_string(oss_pathname) + '-' + utils.md5_string(filepath) + '-download'


def make_upload_store(root=None, dir=None):
    return ResumableStore(root=root, dir=dir)


def make_download_store(root=None, dir=None):
    return ResumableDownloadStore(root=root, dir=dir)


def _rebuild_record(filename, store, bucket, key, upload_id, part_size=None):
    abspath = os.path.abspath(filename)
    mtime = os.path.getmtime(filename)
    size = os.path.getsize(filename)

    store_key = store.make_store_key(bucket.bucket_name, key, abspath)
    record = {'upload_id': upload_id, 'mtime': mtime, 'size': size, 'parts': [],
              'abspath': abspath, 'key': key}

    for p in iterators.PartIterator(bucket, key, upload_id):
        record['parts'].append({'part_number': p.part_number,
                                'etag': p.etag})

        if not part_size:
            part_size = p.size

    record['part_size'] = part_size

    store.put(store_key, record)


def _is_record_sane(record):
    try:
        for key in ('upload_id', 'abspath', 'key'):
            if not isinstance(record[key], str):
                get_logger().info('{0} is not a string: {1}, but {2}'.format(key, record[key], record[key].__class__))
                return False

        for key in ('size', 'part_size'):
            if not isinstance(record[key], int):
                get_logger().info('{0} is not an integer: {1}'.format(key, record[key]))
                return False

        if not isinstance(record['mtime'], int) and not isinstance(record['mtime'], float):
            get_logger().info('mtime is not a float or an integer: {0}'.format(record['mtime']))
            return False

        if not isinstance(record['parts'], list):
            get_logger().info('parts is not a list: {0}'.format(record['parts'].__class__.__name__))
            return False
    except KeyError as e:
        get_logger().info('Key not found: {0}'.format(e.args))
        return False

    return True


class _PartToProcess(object):
    def __init__(self, part_number, start, end):
        self.part_number = part_number
        self.start = start
        self.end = end

    @property
    def size(self):
        return self.end - self.start

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        return self.__key() == other.__key()

    def __key(self):
        return (self.part_number, self.start, self.end)