Skip to content

Cache Decorator for Python Functions

Third-party modules

diskcache 模块提供了一个 Cache 类,可以用于将函数的输出结果缓存到磁盘上。 joblib 模块提供了一个 Memory 类,可以用于内存缓存函数的输出结果。

这两个模块的使用方法都比较简单,可以根据实际需要选择使用。需要注意的是,缓存的数据类型应该是可序列化的,因为它们需要被序列化保存到磁盘或内存中。

python
# pip install diskcache
# poetry add diskcache

from diskcache import Cache

cache = Cache(directory='cache')

@cache.memoize(expire=3600)
def my_func(arg1, arg2):
    """Example: the result is computed at most once per hour per argument pair."""
    # Stand-in for some expensive computation. The original snippet returned
    # an undefined name, which raised NameError when the function was called.
    result = (arg1, arg2)
    return result
python
from diskcache import Cache

cache_ = Cache(directory='cache')

def cache(cache_):
    """Decorator factory that memoizes instance-method results into ``cache_``.

    ``cache_`` may be any mapping-like object supporting ``in``,
    ``__getitem__`` and ``__setitem__`` (e.g. ``diskcache.Cache`` or a dict).
    The key combines class name, method name and the repr of the arguments,
    so arguments must have stable ``repr`` output.
    """
    from functools import wraps  # local import keeps the snippet self-contained

    def decorator(func):
        @wraps(func)  # preserve the wrapped method's name/docstring
        def wrapper(self, *args, **kwargs):
            # Repr-based key: collisions are possible only if two different
            # argument sets share a repr.
            cache_key = f"{self.__class__.__name__}_{func.__name__}_{args}_{kwargs}"
            if cache_key in cache_:
                return cache_[cache_key]
            result = func(self, *args, **kwargs)
            cache_[cache_key] = result
            return result
        return wrapper
    return decorator

def cache_async(cache_):
    """Async variant of ``cache``: memoizes coroutine-method results.

    ``cache_`` may be any mapping-like object supporting ``in``,
    ``__getitem__`` and ``__setitem__``.
    """
    from functools import wraps  # local import keeps the snippet self-contained

    def decorator(func):
        @wraps(func)  # preserve the wrapped coroutine's name/docstring
        async def wrapper(self, *args, **kwargs):
            cache_key = f"{self.__class__.__name__}_{func.__name__}_{args}_{kwargs}"
            if cache_key in cache_:
                return cache_[cache_key]
            # No try/except here: a failing call should propagate unchanged
            # and its (non-)result must not be cached.
            result = await func(self, *args, **kwargs)
            cache_[cache_key] = result
            return result
        return wrapper
    return decorator

class C:
    """Example: applying the instance-method ``cache`` decorator."""

    @cache(cache_)
    def func(self):
        # Placeholder method; a real method's return value would be
        # memoized into ``cache_`` keyed by class/method name and arguments.
        pass
python
from joblib import Memory

memory = Memory(location='cache')

@memory.cache
def my_func(arg1, arg2):
    """Example: the return value is memoized on disk by joblib."""
    # Stand-in for some expensive computation. The original snippet returned
    # an undefined name, which raised NameError when the function was called.
    result = (arg1, arg2)
    return result

Custom implementation

python
import os
import pickle
import hashlib
from functools import wraps
import time

def cache(namespace="unnamed", is_object=False, log=False, refresh=False, timeout=None):
    """Disk-based memoization decorator.

    Results are pickled to ``__cache_result__/<sha256-of-call>``; the key is
    derived from the namespace, the function name and the call arguments.

    Args:
        namespace: Prefix mixed into the key so different call sites caching
            the same function name do not collide.
        is_object: If True, drop the first positional argument (``self``)
            from the key so all instances share one cache entry.
        log: Print a line when a cache entry is loaded or written.
        refresh: If True, always recompute and overwrite the cached value.
        timeout: Max age of a cache file in seconds; ``None`` never expires.
    """
    root_cache = "__cache_result__"
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(root_cache, exist_ok=True)

    def gen_key(*args, **kwargs):
        # Pickle the full call signature, then hash it so the key is a
        # fixed-length, filesystem-safe file name.
        serialized_args = pickle.dumps((args, kwargs))
        return hashlib.sha256(serialized_args).hexdigest()

    def is_cached(key):
        # A cache entry exists iff its file exists.
        return os.path.exists(os.path.join(root_cache, key))

    def dump_cache(key, data):
        # Persist the result under its key.
        with open(os.path.join(root_cache, key), "wb") as fp:
            pickle.dump(data, fp)

    def load_cache(key):
        # Load a previously persisted result.
        with open(os.path.join(root_cache, key), "rb") as fp:
            return pickle.load(fp)

    def inner_decorator(func):
        @wraps(func)
        def decorated(*args, **kwargs):
            # Drop ``self`` from the key when caching instance methods.
            args_ = args[1:] if is_object else args
            cache_key = gen_key(namespace, func.__name__, args_, **kwargs)

            # Serve from cache unless refreshing or the entry has expired.
            if not refresh and is_cached(cache_key):
                if timeout is None or (time.time() - os.path.getmtime(os.path.join(root_cache, cache_key))) < timeout:
                    if log:
                        print("[LOG] LOADING CACHE")
                    return load_cache(cache_key)

            # Miss (or forced refresh): compute, then persist the result.
            result = func(*args, **kwargs)
            if log:
                print("[LOG] DUMPED CACHE")
            dump_cache(cache_key, result)
            return result

        return decorated

    return inner_decorator


# @cache(timeout=5)
# def get_html(url):
#     return time.time()

# print(get_html(''))
python
import os
import pickle
import hashlib
from functools import wraps
import time
import inspect

def cache(namespace="unnamed", is_object=False, log=False, refresh=False, timeout=None):
    """Disk-based memoization decorator for both sync and async functions.

    Results are pickled to ``__cache_result__/<sha256-of-call>`` with an
    atomic write. Expired entries are swept on every decorated call.

    Args:
        namespace: Prefix mixed into the key so call sites do not collide.
        is_object: If True, drop the first positional argument (``self``)
            from the key so all instances share one cache entry.
        log: Print a line on cache load/store/corruption/expiry.
        refresh: If True, always recompute and overwrite the cached value.
        timeout: Max age of a cache file in seconds; ``None`` disables expiry.

    Note: a cached ``None`` result is indistinguishable from a miss and is
    recomputed on every call (same behavior as the sync-only version above).
    """
    root_cache = "__cache_result__"
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(root_cache, exist_ok=True)

    def gen_key(*args, **kwargs):
        # Pickle the full call signature, then hash it so the key is a
        # fixed-length, filesystem-safe file name.
        serialized_args = pickle.dumps((args, kwargs))
        return hashlib.sha256(serialized_args).hexdigest()

    def is_cached(key):
        # A cache entry exists iff its file exists.
        return os.path.exists(os.path.join(root_cache, key))

    def dump_cache(key, data):
        # Write to a temp file first, then rename: readers never observe a
        # partially written cache file.
        temp_path = os.path.join(root_cache, f"{key}.tmp")
        final_path = os.path.join(root_cache, key)
        with open(temp_path, "wb") as fp:
            pickle.dump(data, fp)
        os.replace(temp_path, final_path)  # atomic on POSIX and Windows

    def load_cache(key):
        # Load a persisted result; treat a corrupted file as a miss so the
        # caller recomputes instead of crashing.
        try:
            with open(os.path.join(root_cache, key), "rb") as fp:
                return pickle.load(fp)
        except (EOFError, pickle.UnpicklingError):
            if log:
                print("[LOG] Corrupted cache file detected, ignoring.")
            return None

    def clear_expired_cache():
        # Sweep cache files older than the timeout (no-op when timeout=None).
        if timeout is not None:
            now = time.time()
            for filename in os.listdir(root_cache):
                file_path = os.path.join(root_cache, filename)
                if os.path.isfile(file_path):
                    if now - os.path.getmtime(file_path) >= timeout:
                        os.remove(file_path)
                        if log:
                            print(f"[LOG] Removed expired cache file: {filename}")

    def inner_decorator(func):
        def _lookup(args, kwargs):
            """Return ``(key, cached-value-or-None)`` for this call."""
            args_ = args[1:] if is_object else args
            cache_key = gen_key(namespace, func.__name__, args_, **kwargs)
            if not refresh and is_cached(cache_key):
                path = os.path.join(root_cache, cache_key)
                if timeout is None or (time.time() - os.path.getmtime(path)) < timeout:
                    if log:
                        print("[LOG] LOADING CACHE")
                    return cache_key, load_cache(cache_key)
            return cache_key, None

        def _store(cache_key, result):
            """Persist a freshly computed result under its key."""
            if log:
                print("[LOG] DUMPED CACHE")
            dump_cache(cache_key, result)

        @wraps(func)
        async def async_decorated(*args, **kwargs):
            clear_expired_cache()
            cache_key, cached = _lookup(args, kwargs)
            if cached is not None:
                return cached
            result = await func(*args, **kwargs)
            _store(cache_key, result)
            return result

        @wraps(func)
        def sync_decorated(*args, **kwargs):
            clear_expired_cache()
            cache_key, cached = _lookup(args, kwargs)
            if cached is not None:
                return cached
            result = func(*args, **kwargs)
            _store(cache_key, result)
            return result

        # Pick the wrapper matching the decorated function's flavor.
        return async_decorated if inspect.iscoroutinefunction(func) else sync_decorated

    return inner_decorator
python
from cache import cache
import requests
import time

class Timer:
    """Context manager that measures the wall-clock time of its ``with`` block."""

    def __enter__(self):
        # Initialized up front so __str__ is safe even if called before
        # the block finishes (the original raised AttributeError then).
        self.elapsed_time = 0.0
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed_time = time.time() - self.start_time
        # Returning None lets any exception from the block propagate.

    def __str__(self):
        return f"Elapsed time: {self.elapsed_time:.3f} seconds"

@cache()
def get_ip():
    """Download a large text file and return its body as a string.

    NOTE(review): the name says "ip" but the URL is a Project Gutenberg
    e-book (War and Peace) — presumably the URL changed after naming;
    confirm the intended target.
    """
    response = requests.get("https://www.gutenberg.org/files/2600/2600-0.txt")
    return response.text

# Time the (possibly cached) download and report how long it took.
with Timer() as timer:
    body = get_ip()
    print(body)

print(timer)

other

python
def json_dump(filename):
    """Decorator: write the wrapped function's return value to *filename* as JSON.

    The file is rewritten on every call; the result must be JSON-serializable.
    The result is also returned unchanged to the caller.
    """
    import json
    from functools import wraps

    def decorator(func):
        @wraps(func)  # preserve the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            with open(filename, 'w', encoding='utf-8') as fp:
                json.dump(result, fp, indent=2, ensure_ascii=False)
            return result
        return wrapper
    return decorator

@json_dump("cache.json")
def func():
    pass

Released under the MIT License.