basic download function

This commit is contained in:
Bigsk 2024-10-14 18:13:18 +08:00
parent b96a458952
commit c676cd2eb5
6 changed files with 112 additions and 4 deletions

View File

@ -1 +1,2 @@
from .config import * from .config import *
from .core import *

View File

@ -0,0 +1 @@

View File

@ -5,4 +5,4 @@ __version__ = '0.0.1'
# Package metadata; each value is a plain string constant.
__author__ = "Ian Xia"
__author_email__ = "xia@ghink.net"
__license__ = "MIT"
__copyright__ = "Copyright Ian Xia"

View File

@ -1,13 +1,45 @@
import richka import richka
import logging
# Components interpolated into the default User-Agent string below.
__VERSION = ("Alpha", 0, 0, 1)
# Default User-Agent; evaluates to "RichkaAlpha/0.0.1" for the tuple above.
USER_AGENT = f"Richka{__VERSION[0]}/{__VERSION[1]}.{__VERSION[2]}.{__VERSION[3]}"
# Shared request headers; keys are stored in lower case.
HEADERS = {"user-agent": USER_AGENT}
# Number of slices (one coroutine each) a sliced download is split into.
COROUTINE_LIMIT = 10
SLICE_THRESHOLD = 10  # MiB
# Logger shared by the package for progress messages.
logger = logging.getLogger("Richka Engine")
def set_user_agent(user_agent: str) -> None:
    """
    Set Public User Agent for HTTP Requests

    The value is stored as ``richka.USER_AGENT`` and mirrored into the shared
    ``richka.HEADERS`` mapping under the lower-case ``user-agent`` key, so both
    always agree after this call.
    :param user_agent: String
    :return: None
    """
    richka.USER_AGENT = user_agent
    richka.HEADERS["user-agent"] = user_agent
def set_headers(headers: dict) -> None:
    """
    Set Public Headers for HTTP Requests

    Each entry is merged into the shared ``richka.HEADERS`` mapping; entries
    that are not named in ``headers`` are left untouched. Header names are
    lower-cased before being stored so that a name cannot appear twice with
    different capitalisation.
    :param headers: Dictionary
    :return: None
    """
    for key, value in headers.items():
        name = key.lower()
        richka.HEADERS[name] = value
        # Keep the standalone constant in step with the mapping, the same way
        # set_user_agent() does, so the two never disagree.
        if name == "user-agent":
            richka.USER_AGENT = value
def set_coroutine_limit(coroutine_limit: int) -> None:
    """
    Set Coroutine Limit for HTTP Requests

    The limit is the number of slices a sliced download is divided into, so it
    must be a positive number: zero cannot be used as a divisor and a negative
    count produces no slices at all.
    :param coroutine_limit: Integer, at least 1
    :return: None
    :raises ValueError: if coroutine_limit is smaller than 1
    """
    if coroutine_limit < 1:
        raise ValueError(f"coroutine_limit must be at least 1, got {coroutine_limit!r}")
    richka.COROUTINE_LIMIT = coroutine_limit
def set_slice_threshold(slice_threshold: int) -> None:
    """
    Replace the package-wide slice threshold.

    The new value is stored as ``richka.SLICE_THRESHOLD``; the default for
    that setting is expressed in MiB.
    :param slice_threshold: Integer, new threshold value
    :return: None
    """
    richka.SLICE_THRESHOLD = slice_threshold

View File

@ -0,0 +1,74 @@
import time
import asyncio
import richka
import aiohttp
async def __download_range(session: aiohttp.ClientSession, url: str, start: int, end: int, destination: str) -> None:
    """
    Download one byte range of url into the matching offset of destination

    The destination file must already exist (it is opened with ``r+b``); the
    received bytes are written starting at offset ``start``.
    :param session: aiohttp session used to send the request
    :param url: String
    :param start: Integer, first byte of the range (inclusive)
    :param end: Integer, last byte of the range (inclusive)
    :param destination: String, path of the existing file to write into
    :return: None
    :raises aiohttp.ClientResponseError: on an HTTP error status, or when the
        server does not answer with 206 Partial Content
    """
    richka.logger.info(f'Downloading part {start}-{end} of {url} to {destination}.')
    headers = {**richka.HEADERS, 'range': f'bytes={start}-{end}'}
    async with session.get(url, headers=headers) as response:
        response.raise_for_status()
        # A server that ignores the Range header answers 200 with the whole
        # body; writing that at `start` would corrupt the file, so refuse it.
        if response.status != 206:
            raise aiohttp.ClientResponseError(
                response.request_info,
                response.history,
                status=response.status,
                message=f'Expected 206 Partial Content for range {start}-{end}, got {response.status}.',
                headers=response.headers,
            )
        with open(destination, 'r+b') as f:
            f.seek(start)
            # Stream in chunks so a large part is never held in memory at once.
            async for chunk in response.content.iter_chunked(64 * 1024):
                f.write(chunk)
    richka.logger.info(f'Downloaded part {start}-{end} of {destination}.')
async def __download_single(session: aiohttp.ClientSession, url: str, destination: str) -> None:
    """
    Download url into destination with a single request

    The destination is created or truncated before writing, so the resulting
    file holds exactly the bytes that were received, whatever size the file
    had beforehand.
    :param session: aiohttp session used to send the request
    :param url: String
    :param destination: String, path of the file to write
    :return: None
    :raises aiohttp.ClientResponseError: on an HTTP error status
    """
    richka.logger.info(f'Downloading {url} to {destination}.')
    async with session.get(url, headers=richka.HEADERS) as response:
        # Do not store an error page as if it were the requested file.
        response.raise_for_status()
        with open(destination, 'wb') as f:
            # Stream in chunks so the whole body is never held in memory.
            async for chunk in response.content.iter_chunked(64 * 1024):
                f.write(chunk)
    richka.logger.info(f'Downloaded {url} to {destination}.')
async def download(url: str, destination: str) -> float:
    """
    Download url to destination and report how long the transfer took

    A HEAD request is sent first to learn the file size. The file is fetched
    with one request when the size is unknown, when it does not exceed
    ``richka.SLICE_THRESHOLD`` MiB, or when the server does not advertise byte
    range support. Otherwise it is split into ``richka.COROUTINE_LIMIT`` slices
    that are downloaded concurrently into a pre-sized file.
    :param url: String
    :param destination: String, path of the file to write
    :return: Float, seconds spent on the transfer (the HEAD probe is excluded)
    """
    async with aiohttp.ClientSession() as session:
        # Get file size. Send the shared headers and follow redirects so the
        # answer describes the same resource the later GET requests receive.
        async with session.head(url, headers=richka.HEADERS, allow_redirects=True) as response:
            if response.status == 200:
                file_size = int(response.headers.get('Content-Length', 0))
                accepts_ranges = 'bytes' in response.headers.get('Accept-Ranges', '').lower()
            else:
                # The headers of an error response say nothing about the file.
                file_size = 0
                accepts_ranges = False

        threshold = richka.SLICE_THRESHOLD * pow(1024, 2)
        if not file_size or file_size <= threshold or not accepts_ranges:
            if not file_size:
                richka.logger.info(f'Failed to get file size, directly downloading {url}.')
            else:
                richka.logger.info(f'Downloading {url} ({file_size}) to {destination} with single mode.')

            # Create an empty file. It is deliberately not pre-sized: a body
            # shorter than the advertised size would leave trailing zero bytes.
            with open(destination, 'wb'):
                pass

            # Start task
            start_time = time.perf_counter()
            await __download_single(session, url, destination)
            return time.perf_counter() - start_time

        richka.logger.info(f'Downloading {url} ({file_size}) to {destination} with slicing mode.')

        # Calc slice size. Clamp the slice count so a limit of zero cannot be
        # used as a divisor and no slice is ever empty.
        slice_count = max(1, min(richka.COROUTINE_LIMIT, file_size))
        part_size = file_size // slice_count

        # Create a file of the final size so every slice can seek to its offset
        with open(destination, 'wb') as f:
            f.truncate(file_size)

        # Create coroutine tasks; the last slice absorbs the division remainder
        last_index = slice_count - 1
        tasks = [
            __download_range(
                session,
                url,
                index * part_size,
                file_size - 1 if index == last_index else (index + 1) * part_size - 1,
                destination,
            )
            for index in range(slice_count)
        ]

        # Start all task
        start_time = time.perf_counter()
        await asyncio.gather(*tasks)
        return time.perf_counter() - start_time

View File

@ -39,4 +39,4 @@ setup(
}, },
package_data={'': ['README.md']}, package_data={'': ['README.md']},
include_package_data=True, include_package_data=True,
zip_safe=False) zip_safe=False)