Last active
May 15, 2023 00:34
-
-
Save a-canela/4cbbe20b08ce1fa92ff373d5b60ac9ef to your computer and use it in GitHub Desktop.
Directory download (sync) through boto3, supporting an include-only filter pattern
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from concurrent.futures import ThreadPoolExecutor
from fnmatch import fnmatch
from multiprocessing import Process, Queue
from typing import Optional

import boto3
from botocore.config import Config
class S3Client:
    """AWS S3 boto3-based client for downloading/syncing folders.

    Listing and local bookkeeping run in the calling process; actual
    downloads are handed to a background process that fans the work out
    over a thread pool.
    """

    def __init__(
        self,
        bucket: str = '',
        max_attempts: int = 10,
        max_pool_connections: int = 100,
        max_download_workers: int = 20,
    ) -> None:
        """AWS S3 boto3-based client constructor.

        Note: Implementation only valid for UNIX (the downloader background
        process relies on fork inheriting this object's state).

        Args:
            bucket (str): Bucket path; a leading 's3://' scheme is stripped.
            max_attempts (int, optional): Maximum retry attempts.
                Defaults to 10.
            max_pool_connections (int, optional): Maximum number of concurrent
                requests to AWS S3. Defaults to 100.
            max_download_workers (int, optional): Maximum number of workers for
                downloading files. Defaults to 20.
        """
        self.bucket = bucket.replace('s3://', '')
        self.max_attempts = max_attempts
        config = Config(
            retries={
                'mode': 'standard',
                'max_attempts': max_attempts,
            },
            max_pool_connections=max_pool_connections,
        )
        self.client = boto3.client('s3', config=config)
        self.max_download_workers = max_download_workers
        self.paginator = self.client.get_paginator('list_objects_v2')

    def download_files(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Download files (either a folder recursively or a single file) from S3.

        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        self.downloader_queue = Queue()
        self.downloader = Process(target=self._downloader)
        self.downloader.start()
        try:
            dst_dir = dst_dir.rstrip('/')
            prefix = prefix.rstrip('/')
            bucket = bucket or self.bucket
            self._enqueue_downloads(
                dst_dir=dst_dir,
                prefix=prefix,
                pattern=pattern,
                bucket=bucket,
            )
        finally:
            # Always send the sentinel so the downloader process can exit,
            # even if enumeration failed half-way through.
            self.downloader_queue.put(None)
            self.downloader.join()

    def _enqueue_downloads(
        self,
        dst_dir: str,
        prefix: str,
        pattern: Optional[str] = None,
        bucket: str = '',
    ) -> None:
        """Enqueue files to be downloaded by the downloader background process.

        Files whose local copy already matches the remote size and timestamp
        are skipped (sync behavior).

        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        created_dirs = set()
        for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
            files = page.get('Contents', ())
            # A page holding exactly the requested key means a single-file
            # download: store it under its basename instead of a sub-path.
            single_file = len(files) == 1 and files[0].get('Key') == prefix
            for file in files:
                key = file.get('Key')
                if single_file:
                    sub_path = os.path.basename(key)
                else:
                    # Strip only the leading '<prefix>/' from the key.
                    # str.replace(f'{prefix}/', '') would also remove later
                    # occurrences of the prefix inside the key, producing a
                    # wrong local sub-path (e.g. 'data/sub/data/a.txt').
                    sub_path = key[len(prefix) + 1:] if prefix else key
                if pattern and not fnmatch(sub_path, pattern):
                    continue
                timestamp = file.get('LastModified').timestamp()
                size = file.get('Size')
                dst_path = f'{dst_dir}/{sub_path}'
                if os.path.exists(dst_path):
                    dst_file_stat = os.stat(dst_path)
                    # NOTE(review): exact float comparison; assumes os.utime
                    # round-trips the S3 timestamp on the target filesystem,
                    # so equality means the file is already in sync.
                    if (timestamp == dst_file_stat.st_mtime
                            and size == dst_file_stat.st_size):
                        continue
                parent = os.path.dirname(dst_path)
                if parent not in created_dirs:
                    os.makedirs(parent, exist_ok=True)
                    created_dirs.add(parent)
                self.downloader_queue.put((bucket, key, dst_path, timestamp))

    def _downloader(self) -> None:
        """Downloader background process.

        Drains the queue into a thread pool until the None sentinel arrives,
        then waits for all in-flight downloads to finish.
        """
        executor = ThreadPoolExecutor(max_workers=self.max_download_workers)
        try:
            args = self.downloader_queue.get()
            # Compare against the explicit None sentinel rather than
            # truthiness, so a falsy work item can never stop the loop early.
            while args is not None:
                executor.submit(self._download_file, *args)
                args = self.downloader_queue.get()
        finally:
            executor.shutdown(wait=True)

    def _download_file(
        self,
        bucket: str,
        key: str,
        dst_path: str,
        timestamp: float,
    ) -> None:
        """Download a file and assign the provided timestamp to it.

        Mirroring the S3 timestamp locally is what lets _enqueue_downloads
        skip already-synced files on subsequent runs.

        Args:
            bucket (str): File S3 bucket.
            key (str): File S3 key.
            dst_path (str): File destination local path.
            timestamp (float): File S3 timestamp.
        """
        self.client.download_file(bucket, key, dst_path)
        os.utime(dst_path, (timestamp, timestamp))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment