123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- from __future__ import annotations
- import asyncio
- from asyncio import AbstractEventLoop
- from concurrent.futures import ThreadPoolExecutor
- from abc import abstractmethod
- import json
- from inspect import signature, Parameter
- from typing import Optional, _GenericAlias
- from pathlib import Path
- try:
- from types import NoneType
- except ImportError:
- NoneType = type(None)
- from ..typing import CreateResult, AsyncResult, Messages
- from .types import BaseProvider
- from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
- from .response import BaseConversation, AuthResult
- from .helper import concat_chunks, async_concat_chunks
- from ..cookies import get_cookies_dir
- from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
- from .. import debug
# Names of create-function arguments that may be surfaced to callers
# (filtered in AbstractProvider.get_parameters). Order is preserved.
SAFE_PARAMETERS = """
    model messages stream timeout
    proxy images response_format
    prompt negative_prompt tools conversation
    history_disabled auto_continue
    temperature top_k top_p
    frequency_penalty presence_penalty
    max_tokens max_new_tokens stop
    api_key api_base seed width height
    proof_token max_retries web_search
    guidance_scale num_inference_steps randomize_seed
""".split()

# Baseline entries always present in the JSON parameter description;
# provider-specific parameters are merged over these.
BASIC_PARAMETERS = dict(
    provider=None,
    model="",
    messages=[],
    stream=False,
    timeout=0,
    response_format=None,
    max_tokens=None,
    stop=None,
)

# Illustrative values shown for parameters in the JSON description,
# taking precedence over values derived from type annotations.
PARAMETER_EXAMPLES = dict(
    proxy="http://user:password@127.0.0.1:3128",
    temperature=1,
    top_k=1,
    top_p=1,
    frequency_penalty=1,
    presence_penalty=1,
    messages=[
        {"role": "system", "content": ""},
        {"role": "user", "content": ""},
    ],
    images=[["data:image/jpeg;base64,...", "filename.jpg"]],
    response_format={"type": "json_object"},
    conversation={
        "conversation_id": "550e8400-e29b-11d4-a716-...",
        "message_id": "550e8400-e29b-11d4-a716-...",
    },
    max_new_tokens=1024,
    max_tokens=4096,
    seed=42,
)
class AbstractProvider(BaseProvider):
    """
    Base class for providers that expose a synchronous ``create_completion``.

    ``create_async`` is derived from it by running the synchronous call in an
    executor. ``get_parameters`` and ``params`` introspect the provider's primary
    create function to describe the arguments it supports.
    """

    @classmethod
    @abstractmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """
        Create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            stream (bool): Whether to use streaming.
            **kwargs: Additional keyword arguments.

        Returns:
            CreateResult: The result of the creation process.
        """
        raise NotImplementedError()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        *,
        timeout: int = None,
        loop: AbstractEventLoop = None,
        executor: ThreadPoolExecutor = None,
        **kwargs
    ) -> str:
        """
        Run ``create_completion`` in an executor and return its joined output.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            timeout (int, optional): Maximum seconds to wait for the result;
                None waits indefinitely.
            loop (AbstractEventLoop, optional): Loop whose executor runs the call.
                Defaults to the currently running loop.
            executor (ThreadPoolExecutor, optional): Executor for the synchronous
                call. None selects the loop's default executor.
            **kwargs: Forwarded to ``create_completion``.

        Returns:
            str: All chunks from ``create_completion`` combined by ``concat_chunks``.

        Raises:
            asyncio.TimeoutError: If ``timeout`` elapses before completion.
        """
        loop = asyncio.get_running_loop() if loop is None else loop

        def create_func() -> str:
            return concat_chunks(cls.create_completion(model, messages, **kwargs))

        return await asyncio.wait_for(
            loop.run_in_executor(executor, create_func),
            timeout=timeout
        )

    @classmethod
    def get_create_function(cls) -> callable:
        """Return the synchronous entry point, ``create_completion``."""
        return cls.create_completion

    @classmethod
    def get_async_create_function(cls) -> callable:
        """Return the asynchronous entry point, ``create_async``."""
        return cls.create_async

    @classmethod
    def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
        """
        Describe the arguments accepted by the provider's primary create function.

        The inspected function is ``create_async_generator`` for
        AsyncGeneratorProvider subclasses, ``create_async`` for AsyncProvider
        subclasses and ``create_completion`` otherwise. Only names listed in
        SAFE_PARAMETERS are kept. ``stream`` is kept only when
        ``cls.supports_stream`` is truthy.

        Args:
            as_json (bool): When True, return example values instead of
                ``inspect.Parameter`` objects.

        Returns:
            dict: ``Parameter`` objects keyed by name. When ``as_json`` is set,
            returns a mapping built from BASIC_PARAMETERS, overlaid with the
            filtered parameters, then with ``provider``/``model``/``stream``
            taken from the class.
        """
        create_function = (
            cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
            cls.create_async if issubclass(cls, AsyncProvider) else
            cls.create_completion
        )
        params = {
            name: parameter
            for name, parameter in signature(create_function).parameters.items()
            if name in SAFE_PARAMETERS and (name != "stream" or cls.supports_stream)
        }
        if not as_json:
            return params

        from typing import Union

        def get_type_as_var(annotation: type, key: str, default):
            """Return a JSON-friendly example value for a single parameter."""
            # Curated examples win over anything derived from the annotation.
            if key in PARAMETER_EXAMPLES:
                if key == "messages" and not cls.supports_system_message:
                    # Keep only the user-message example for providers
                    # that do not accept a system message.
                    return [PARAMETER_EXAMPLES[key][-1]]
                return PARAMETER_EXAMPLES[key]
            if isinstance(annotation, type):
                # bool must be tested before int: bool is a subclass of int,
                # so testing int first would make this branch unreachable.
                if issubclass(annotation, bool):
                    return False
                elif issubclass(annotation, int):
                    return 0
                elif issubclass(annotation, float):
                    return 0.0
                elif issubclass(annotation, str):
                    return ""
                elif issubclass(annotation, dict):
                    return {}
                elif issubclass(annotation, list):
                    return []
                elif issubclass(annotation, BaseConversation):
                    return {}
                elif issubclass(annotation, NoneType):
                    return {}
            elif annotation is None:
                return None
            elif annotation == "str" or annotation == "list[str]":
                # String annotations (e.g. under `from __future__ import annotations`).
                return default
            elif isinstance(annotation, _GenericAlias):
                # Optional[X] is Union[X, None]: its __origin__ is Union, never Optional.
                # Describe it by its first non-None member.
                if annotation.__origin__ is Union:
                    members = [arg for arg in annotation.__args__ if arg is not NoneType]
                    if members:
                        return get_type_as_var(members[0], key, default)
            else:
                return str(annotation)

        merged = {
            **BASIC_PARAMETERS,
            **params,
            "provider": cls.__name__,
            "model": getattr(cls, "default_model", ""),
            "stream": cls.supports_stream,
        }
        return {
            name: (
                # A concrete, non-None default wins; otherwise derive an example.
                param.default
                if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
                else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
            )
            for name, param in merged.items()
        }

    @classmethod
    @property
    def params(cls) -> str:
        """
        Human-readable summary of the parameters supported by the provider.

        Returns:
            str: ``"g4f.Provider.<Name> supports: (...)"``. Each line holds a
            parameter name, its annotation and its default (if any). ``model``
            shows the class's ``default_model``.
        """
        # NOTE(review): stacking @classmethod over @property is deprecated since
        # Python 3.11 and no longer yields a class-level property on 3.13+.
        # It is kept here because callers read ``Provider.params`` as an attribute.
        def get_type_name(annotation: type) -> str:
            return getattr(annotation, "__name__", str(annotation)) if annotation is not Parameter.empty else ""

        entries = []
        for name, param in cls.get_parameters().items():
            default_value = getattr(cls, "default_model", "") if name == "model" else param.default
            default_value = f'"{default_value}"' if isinstance(default_value, str) else default_value
            default_part = f" = {default_value}" if param.default is not Parameter.empty else ""
            entries.append(f"\n {name}: {get_type_name(param.annotation)}{default_part},")
        args = "".join(entries)
        return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider):
    """
    Base class for providers implemented as a single coroutine.

    Subclasses supply ``create_async``. The synchronous ``create_completion``
    drives that coroutine with ``asyncio.run`` and yields its single result.
    """

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """
        Synchronously produce the result of ``create_async``.

        Args:
            model (str): Name of the model to use.
            messages (Messages): Conversation to send.
            stream (bool): Accepted for interface compatibility but not passed
                on to ``create_async``. Defaults to False.
            **kwargs: Forwarded unchanged to ``create_async``.

        Yields:
            str: The complete result returned by ``create_async``, once.
        """
        # Helper from .asyncio invoked before starting a fresh loop. Presumably it
        # inspects/handles an already-running loop; see its definition.
        get_running_loop(check_nested=False)
        pending = cls.create_async(model, messages, **kwargs)
        outcome = asyncio.run(pending)
        yield outcome

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Produce the full response text. Subclasses must override this.

        Args:
            model (str): Name of the model to use.
            messages (Messages): Conversation to send.
            **kwargs: Provider-specific options.

        Returns:
            str: The complete response.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def get_create_function(cls) -> callable:
        """Return the synchronous entry point, ``create_completion``."""
        return cls.create_completion

    @classmethod
    def get_async_create_function(cls) -> callable:
        """Return the coroutine entry point, ``create_async``."""
        return cls.create_async
class AsyncGeneratorProvider(AbstractProvider):
    """
    Base class for providers that stream their output from an async generator.

    Subclasses supply ``create_async_generator``. ``create_completion`` adapts it
    to a synchronous iterator via ``to_sync_generator``.
    """

    # Streaming is the natural mode for generator-based providers.
    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """
        Expose ``create_async_generator`` as a synchronous iterator.

        Args:
            model (str): Name of the model to use.
            messages (Messages): Conversation to send.
            stream (bool): Passed both to ``create_async_generator`` and to
                ``to_sync_generator``. Defaults to True.
            **kwargs: Forwarded unchanged to ``create_async_generator``.

        Returns:
            CreateResult: Whatever ``to_sync_generator`` builds around the async generator.
        """
        async_gen = cls.create_async_generator(model, messages, stream=stream, **kwargs)
        return to_sync_generator(async_gen, stream=stream)

    @staticmethod
    @abstractmethod
    async def create_async_generator(
        model: str,
        messages: Messages,
        stream: bool = True,
        **kwargs
    ) -> AsyncResult:
        """
        Yield response chunks asynchronously. Subclasses must override this.

        Args:
            model (str): Name of the model to use.
            messages (Messages): Conversation to send.
            stream (bool): Whether the caller asked for streaming. Defaults to True.
            **kwargs: Provider-specific options.

        Returns:
            AsyncResult: Async iterator over response chunks.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def get_create_function(cls) -> callable:
        """Return the synchronous entry point, ``create_completion``."""
        return cls.create_completion

    @classmethod
    def get_async_create_function(cls) -> callable:
        """Return the async-generator entry point, ``create_async_generator``."""
        return cls.create_async_generator
class ProviderModelMixin:
    """
    Mixin providing model listing, alias resolution and validation.

    Class attributes are meant to be overridden by concrete providers. Note that
    ``models`` and ``model_aliases`` default to shared mutable containers, so
    subclasses should assign their own instead of mutating these in place.
    """

    default_model: str = None
    models: list[str] = []
    model_aliases: dict[str, str] = {}
    image_models: list = None
    last_model: str = None

    @classmethod
    def get_models(cls, **kwargs) -> list[str]:
        """
        Return the provider's model names.

        Falls back to ``[default_model]`` when ``models`` is empty and a default
        model is set.
        """
        if cls.models or cls.default_model is None:
            return cls.models
        return [cls.default_model]

    @classmethod
    def get_model(cls, model: str, **kwargs) -> str:
        """
        Resolve a requested model name.

        An empty name becomes ``default_model`` (if set). Otherwise a name found
        in ``model_aliases`` is mapped to its target. Otherwise the name is checked
        against ``get_models``. The resolved name is recorded on the class and on
        ``debug.last_model``.

        Raises:
            ModelNotSupportedError: If the name is unknown and ``models`` is non-empty.
        """
        if not model and cls.default_model is not None:
            resolved = cls.default_model
        elif model in cls.model_aliases:
            resolved = cls.model_aliases[model]
        else:
            # get_models() is evaluated before cls.models is inspected, matching
            # the original order; an override may populate cls.models while listing.
            known_models = cls.get_models(**kwargs)
            if model not in known_models and cls.models:
                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
            resolved = model
        cls.last_model = resolved
        debug.last_model = resolved
        return resolved
class RaiseErrorMixin():
    """Mixin that turns an API error payload into a ``ResponseError``."""

    @staticmethod
    def raise_error(data: dict):
        """
        Raise ``ResponseError`` if ``data`` describes an error; otherwise return None.

        Recognised shapes:
            {"error_message": "..."}
            {"error": {"code": ..., "message": "..."}}
            {"error": {"message": "..."}}
            {"error": <any other non-None value>}
        An ``"error": None`` entry is treated as "no error".

        Raises:
            ResponseError: Describing the error found in ``data``.
        """
        if "error_message" in data:
            raise ResponseError(data["error_message"])
        if "error" not in data:
            return
        error = data["error"]
        if error is None:
            # Many APIs send "error": null on success; `"code" in None` used to
            # raise TypeError here.
            return
        # Only dicts are inspected for code/message. For a string, `"code" in error`
        # would be a substring test followed by a failing string index.
        if isinstance(error, dict):
            if "code" in error:
                raise ResponseError(f'Error {error["code"]}: {error.get("message", error)}')
            if "message" in error:
                raise ResponseError(error["message"])
        raise ResponseError(error)
class AsyncAuthedProvider(AsyncGeneratorProvider):
    """
    Base class for providers that authenticate before generating a response.

    Subclasses implement ``create_authed(model, messages, auth_result, **kwargs)``
    and usually override ``on_auth_async``. An ``AuthResult`` (anything exposing
    ``get_dict``) is cached as JSON in ``get_cache_file()`` and reused by later calls.
    If generation raises MissingAuthError or NoValidHarFileError, authentication
    runs again and the request is retried once.
    """

    @classmethod
    async def on_auth_async(cls, **kwargs) -> AuthResult:
        """
        Authenticate asynchronously.

        The default only requires an ``api_key`` keyword. Overrides may be a
        coroutine returning an ``AuthResult`` or an async generator yielding
        progress chunks plus an ``AuthResult``.

        Raises:
            MissingAuthError: If ``api_key`` is not among ``kwargs``.
        """
        if "api_key" not in kwargs:
            raise MissingAuthError(f"API key is required for {cls.__name__}")
        return AuthResult()

    @classmethod
    def on_auth(cls, **kwargs) -> AuthResult:
        """
        Synchronous counterpart of ``on_auth_async``.

        Returns a sync generator when ``on_auth_async`` is an async generator.
        Otherwise runs the coroutine with ``asyncio.run`` and returns its result.
        """
        auth_result = cls.on_auth_async(**kwargs)
        if hasattr(auth_result, "__aiter__"):
            return to_sync_generator(auth_result)
        return asyncio.run(auth_result)

    @classmethod
    def get_create_function(cls) -> callable:
        """Return the synchronous entry point, ``create_completion``."""
        return cls.create_completion

    @classmethod
    def get_async_create_function(cls) -> callable:
        """Return the async-generator entry point, ``create_async_generator``."""
        return cls.create_async_generator

    @classmethod
    def get_cache_file(cls) -> Path:
        """Path of the JSON auth cache, named after ``cls.parent`` if present, else the class."""
        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

    @classmethod
    def _read_cached_auth(cls, cache_file: Path) -> Optional[AuthResult]:
        """Load the cached AuthResult; return None when absent or unreadable."""
        if not cache_file.exists():
            return None
        try:
            with cache_file.open("r") as f:
                return AuthResult(**json.load(f))
        except (OSError, ValueError, TypeError):
            # Corrupt or incompatible cache: fall back to a fresh authentication.
            return None

    @classmethod
    def _store_auth(cls, cache_file: Path, auth_result) -> None:
        """Persist ``auth_result`` if it exposes ``get_dict``; otherwise drop any stale cache."""
        if hasattr(auth_result, "get_dict"):
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            cache_file.write_text(json.dumps(auth_result.get_dict()))
        elif cache_file.exists():
            cache_file.unlink()

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> CreateResult:
        """
        Authenticate (or reuse cached auth) and stream ``create_authed`` synchronously.

        Non-auth chunks produced while authenticating are yielded to the caller.
        """
        cache_file = cls.get_cache_file()
        auth_result = cls._read_cached_auth(cache_file)
        try:
            # Up to two attempts: the first may use cached credentials; the second
            # always re-authenticates after MissingAuthError/NoValidHarFileError.
            for attempt in range(2):
                try:
                    if attempt or auth_result is None:
                        auth_result = None
                        auth = cls.on_auth(**kwargs)
                        # Only iterate real iterables. Catching TypeError instead would
                        # also hide TypeErrors raised inside the auth generator.
                        if hasattr(auth, "__iter__"):
                            for chunk in auth:
                                if hasattr(chunk, "get_dict"):
                                    auth_result = chunk
                                else:
                                    yield chunk
                        else:
                            auth_result = auth
                    yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
                    return
                except (MissingAuthError, NoValidHarFileError):
                    if attempt:
                        raise
                    if cache_file.exists():
                        cache_file.unlink()
        finally:
            cls._store_auth(cache_file, auth_result)

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> AsyncResult:
        """
        Authenticate (or reuse cached auth) and stream ``create_authed`` asynchronously.

        Non-auth chunks produced while authenticating are yielded to the caller.
        """
        cache_file = cls.get_cache_file()
        auth_result = cls._read_cached_auth(cache_file)
        try:
            # Up to two attempts: the first may use cached credentials; the second
            # always re-authenticates after MissingAuthError/NoValidHarFileError.
            for attempt in range(2):
                try:
                    if attempt or auth_result is None:
                        auth_result = None
                        auth = cls.on_auth_async(**kwargs)
                        if hasattr(auth, "__aiter__"):
                            async for chunk in auth:
                                # Same auth-chunk test as the sync path. isinstance() against
                                # the subscripted AsyncResult alias raises TypeError.
                                if hasattr(chunk, "get_dict"):
                                    auth_result = chunk
                                else:
                                    yield chunk
                        else:
                            auth_result = await auth
                    response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
                    async for chunk in response:
                        yield chunk
                    return
                except (MissingAuthError, NoValidHarFileError):
                    if attempt:
                        raise
                    if cache_file.exists():
                        cache_file.unlink()
        finally:
            cls._store_auth(cache_file, auth_result)
|