Source code for neuro_sdk._bucket_base

import abc
import enum
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import PurePosixPath
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from yarl import URL

from ._rewrite import rewrite_module
from ._utils import AsyncContextManager


@rewrite_module
@dataclass(frozen=True)  # type: ignore
class BucketEntry(abc.ABC):
    key: str
    bucket: "Bucket"
    size: int
    created_at: Optional[datetime] = None
    modified_at: Optional[datetime] = None

    @property
    def name(self) -> str:
        return PurePosixPath(self.key).name

    @property
    def uri(self) -> URL:
        # Bucket key is an arbitrary string, so it can start with "/",
        # so we have to use this way to append it to bucket url
        return URL(str(self.bucket.uri) + "/" + self.key)

[docs] @abc.abstractmethod def is_file(self) -> bool: pass
[docs] @abc.abstractmethod def is_dir(self) -> bool: pass
@rewrite_module class BlobObject(BucketEntry): def is_file(self) -> bool: return not self.is_dir() def is_dir(self) -> bool: return self.key.endswith("/") and self.size == 0 @rewrite_module class BlobCommonPrefix(BucketEntry): size: int = 0 # This is "folder" analog in blobs # objects of this type will be only returned in # non recursive look-ups, to group multiple keys # in single entry. def is_file(self) -> bool: return False def is_dir(self) -> bool: return True @rewrite_module class BucketProvider(abc.ABC): """ Defines how to execute generic blob operations in a specific bucket provider """ bucket: "Bucket" @classmethod @abc.abstractmethod def create( cls, bucket: "Bucket", _get_credentials: Callable[[], Awaitable["BucketCredentials"]], ) -> AsyncContextManager["BucketProvider"]: pass @abc.abstractmethod def list_blobs( self, prefix: str, recursive: bool = False, limit: Optional[int] = None ) -> AsyncContextManager[AsyncIterator[BucketEntry]]: pass @abc.abstractmethod async def head_blob(self, key: str) -> BucketEntry: pass @abc.abstractmethod async def put_blob( self, key: str, body: Union[AsyncIterator[bytes], bytes], progress: Optional[Callable[[int], Awaitable[None]]] = None, ) -> None: pass @abc.abstractmethod def fetch_blob( self, key: str, offset: int = 0 ) -> AsyncContextManager[AsyncIterator[bytes]]: pass @abc.abstractmethod async def delete_blob( self, key: str, ) -> None: pass @abc.abstractmethod async def get_time_diff_to_local(self) -> Tuple[float, float]: pass @rewrite_module @dataclass(frozen=True) class Bucket: id: str owner: str cluster_name: str org_name: Optional[str] provider: "Bucket.Provider" created_at: datetime imported: bool public: bool = False name: Optional[str] = None @property def uri(self) -> URL: base = f"blob://{self.cluster_name}" if self.org_name: base += f"/{self.org_name}" return URL(f"{base}/{self.owner}/{self.name or self.id}") def get_key_for_uri(self, uri: URL) -> str: self_uris = [self.uri] if self.name: self_uris.append(self.uri.parent / self.id) uri_str = str(uri) for self_uri in self_uris: self_uri_str = str(self_uri) if uri_str.startswith(self_uri_str): return uri_str[len(self_uri_str) :].lstrip("/") raise ValueError(f"URI {uri} is not related to bucket {self.uri}")
[docs] class Provider(str, enum.Enum): AWS = "aws" MINIO = "minio" AZURE = "azure" GCP = "gcp" OPEN_STACK = "open_stack"
@rewrite_module @dataclass(frozen=True) class BucketUsage: total_bytes: int object_count: int @rewrite_module @dataclass(frozen=True) class BucketCredentials: bucket_id: str provider: "Bucket.Provider" credentials: Mapping[str, str] @rewrite_module @dataclass(frozen=True) class PersistentBucketCredentials: id: str owner: str cluster_name: str name: Optional[str] read_only: bool credentials: List[BucketCredentials] @rewrite_module class MeasureTimeDiffMixin: def __init__(self) -> None: self._min_time_diff: Optional[float] = 0 self._max_time_diff: Optional[float] = 0 def _wrap_api_call( self, _make_call: Callable[..., Awaitable[Any]], get_date: Callable[[Any], datetime], ) -> Callable[..., Awaitable[Any]]: @asynccontextmanager async def _ctx_manager(*args: Any, **kwargs: Any) -> AsyncIterator[Any]: yield await _make_call(*args, **kwargs) manager_wrapped = self._wrap_api_call_ctx_manager(_ctx_manager, get_date) async def _wrapper(*args: Any, **kwargs: Any) -> Any: async with manager_wrapped(*args, **kwargs) as res: return res return _wrapper def _wrap_api_call_ctx_manager( self, _make_call: Callable[..., AsyncContextManager[Any]], get_date: Callable[[Any], datetime], ) -> Callable[..., AsyncContextManager[Any]]: def _average(cur_approx: Optional[float], new_val: float) -> float: if cur_approx is None: return new_val return cur_approx * 0.9 + new_val * 0.1 @asynccontextmanager async def _wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[Any]: before = time.time() async with _make_call(*args, **kwargs) as res: after = time.time() yield res try: server_dt = get_date(res) except Exception: pass else: server_time = server_dt.timestamp() # Remove 1 because server time has been truncated # and can be up to 1 second less than the actual value. self._min_time_diff = _average( cur_approx=self._min_time_diff, new_val=before - server_time - 1.0, ) self._max_time_diff = _average( cur_approx=self._min_time_diff, new_val=after - server_time, ) return _wrapper async def get_time_diff_to_local(self) -> Tuple[float, float]: if self._min_time_diff is None or self._max_time_diff is None: return 0, 0 return self._min_time_diff, self._max_time_diff