__init__.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. from __future__ import annotations
  2. import logging
  3. import json
  4. import uvicorn
  5. import secrets
  6. import os
  7. from fastapi import FastAPI, Response, Request
  8. from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
  9. from fastapi.exceptions import RequestValidationError
  10. from fastapi.security import APIKeyHeader
  11. from starlette.exceptions import HTTPException
  12. from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
  13. from fastapi.encoders import jsonable_encoder
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from starlette.responses import FileResponse
  16. from pydantic import BaseModel
  17. from typing import Union, Optional
  18. import g4f
  19. import g4f.debug
  20. from g4f.client import AsyncClient, ChatCompletion
  21. from g4f.providers.response import BaseConversation
  22. from g4f.client.helper import filter_none
  23. from g4f.image import is_accepted_format, images_dir
  24. from g4f.typing import Messages
  25. from g4f.cookies import read_cookie_files
  26. logger = logging.getLogger(__name__)
  27. def create_app(g4f_api_key: str = None):
  28. app = FastAPI()
  29. # Add CORS middleware
  30. app.add_middleware(
  31. CORSMiddleware,
  32. allow_origin_regex=".*",
  33. allow_credentials=True,
  34. allow_methods=["*"],
  35. allow_headers=["*"],
  36. )
  37. api = Api(app, g4f_api_key=g4f_api_key)
  38. api.register_routes()
  39. api.register_authorization()
  40. api.register_validation_exception_handler()
  41. # Read cookie files if not ignored
  42. if not AppConfig.ignore_cookie_files:
  43. read_cookie_files()
  44. return app
  45. def create_app_debug(g4f_api_key: str = None):
  46. g4f.debug.logging = True
  47. return create_app(g4f_api_key)
  48. class ChatCompletionsConfig(BaseModel):
  49. messages: Messages
  50. model: str
  51. provider: Optional[str] = None
  52. stream: bool = False
  53. temperature: Optional[float] = None
  54. max_tokens: Optional[int] = None
  55. stop: Union[list[str], str, None] = None
  56. api_key: Optional[str] = None
  57. web_search: Optional[bool] = None
  58. proxy: Optional[str] = None
  59. conversation_id: str = None
  60. class ImageGenerationConfig(BaseModel):
  61. prompt: str
  62. model: Optional[str] = None
  63. provider: Optional[str] = None
  64. response_format: str = "url"
  65. api_key: Optional[str] = None
  66. proxy: Optional[str] = None
  67. class AppConfig:
  68. ignored_providers: Optional[list[str]] = None
  69. g4f_api_key: Optional[str] = None
  70. ignore_cookie_files: bool = False
  71. model: str = None,
  72. provider: str = None
  73. image_provider: str = None
  74. proxy: str = None
  75. @classmethod
  76. def set_config(cls, **data):
  77. for key, value in data.items():
  78. setattr(cls, key, value)
  79. list_ignored_providers: list[str] = None
  80. def set_list_ignored_providers(ignored: list[str]):
  81. global list_ignored_providers
  82. list_ignored_providers = ignored
  83. class Api:
  84. def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
  85. self.app = app
  86. self.client = AsyncClient()
  87. self.g4f_api_key = g4f_api_key
  88. self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
  89. self.conversations: dict[str, dict[str, BaseConversation]] = {}
  90. def register_authorization(self):
  91. @self.app.middleware("http")
  92. async def authorization(request: Request, call_next):
  93. if self.g4f_api_key and request.url.path in ["/v1/chat/completions", "/v1/completions", "/v1/images/generate"]:
  94. try:
  95. user_g4f_api_key = await self.get_g4f_api_key(request)
  96. except HTTPException as e:
  97. if e.status_code == 403:
  98. return JSONResponse(
  99. status_code=HTTP_401_UNAUTHORIZED,
  100. content=jsonable_encoder({"detail": "G4F API key required"}),
  101. )
  102. if not secrets.compare_digest(self.g4f_api_key, user_g4f_api_key):
  103. return JSONResponse(
  104. status_code=HTTP_403_FORBIDDEN,
  105. content=jsonable_encoder({"detail": "Invalid G4F API key"}),
  106. )
  107. response = await call_next(request)
  108. return response
  109. def register_validation_exception_handler(self):
  110. @self.app.exception_handler(RequestValidationError)
  111. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  112. details = exc.errors()
  113. modified_details = []
  114. for error in details:
  115. modified_details.append({
  116. "loc": error["loc"],
  117. "message": error["msg"],
  118. "type": error["type"],
  119. })
  120. return JSONResponse(
  121. status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  122. content=jsonable_encoder({"detail": modified_details}),
  123. )
  124. def register_routes(self):
  125. @self.app.get("/")
  126. async def read_root():
  127. return RedirectResponse("/v1", 302)
  128. @self.app.get("/v1")
  129. async def read_root_v1():
  130. return HTMLResponse('g4f API: Go to '
  131. '<a href="/v1/models">models</a>, '
  132. '<a href="/v1/chat/completions">chat/completions</a>, or '
  133. '<a href="/v1/images/generate">images/generate</a>.')
  134. @self.app.get("/v1/models")
  135. async def models():
  136. model_list = dict(
  137. (model, g4f.models.ModelUtils.convert[model])
  138. for model in g4f.Model.__all__()
  139. )
  140. model_list = [{
  141. 'id': model_id,
  142. 'object': 'model',
  143. 'created': 0,
  144. 'owned_by': model.base_provider
  145. } for model_id, model in model_list.items()]
  146. return JSONResponse(model_list)
  147. @self.app.get("/v1/models/{model_name}")
  148. async def model_info(model_name: str):
  149. try:
  150. model_info = g4f.models.ModelUtils.convert[model_name]
  151. return JSONResponse({
  152. 'id': model_name,
  153. 'object': 'model',
  154. 'created': 0,
  155. 'owned_by': model_info.base_provider
  156. })
  157. except:
  158. return JSONResponse({"error": "The model does not exist."})
  159. @self.app.post("/v1/chat/completions")
  160. async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
  161. try:
  162. config.provider = provider if config.provider is None else config.provider
  163. if config.provider is None:
  164. config.provider = AppConfig.provider
  165. if config.api_key is None and request is not None:
  166. auth_header = request.headers.get("Authorization")
  167. if auth_header is not None:
  168. api_key = auth_header.split(None, 1)[-1]
  169. if api_key and api_key != "Bearer":
  170. config.api_key = api_key
  171. conversation = return_conversation = None
  172. if config.conversation_id is not None and config.provider is not None:
  173. return_conversation = True
  174. if config.conversation_id in self.conversations:
  175. if config.provider in self.conversations[config.conversation_id]:
  176. conversation = self.conversations[config.conversation_id][config.provider]
  177. # Create the completion response
  178. response = self.client.chat.completions.create(
  179. **filter_none(
  180. **{
  181. "model": AppConfig.model,
  182. "provider": AppConfig.provider,
  183. "proxy": AppConfig.proxy,
  184. **config.dict(exclude_none=True),
  185. **{
  186. "conversation_id": None,
  187. "return_conversation": return_conversation,
  188. "conversation": conversation
  189. }
  190. },
  191. ignored=AppConfig.ignored_providers
  192. ),
  193. )
  194. if not config.stream:
  195. response: ChatCompletion = await response
  196. return JSONResponse(response.to_json())
  197. async def streaming():
  198. try:
  199. async for chunk in response:
  200. if isinstance(chunk, BaseConversation):
  201. if config.conversation_id is not None and config.provider is not None:
  202. if config.conversation_id not in self.conversations:
  203. self.conversations[config.conversation_id] = {}
  204. self.conversations[config.conversation_id][config.provider] = chunk
  205. else:
  206. yield f"data: {json.dumps(chunk.to_json())}\n\n"
  207. except GeneratorExit:
  208. pass
  209. except Exception as e:
  210. logger.exception(e)
  211. yield f'data: {format_exception(e, config)}\n\n'
  212. yield "data: [DONE]\n\n"
  213. return StreamingResponse(streaming(), media_type="text/event-stream")
  214. except Exception as e:
  215. logger.exception(e)
  216. return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
  217. @self.app.post("/v1/images/generate")
  218. @self.app.post("/v1/images/generations")
  219. async def generate_image(config: ImageGenerationConfig, request: Request):
  220. if config.api_key is None:
  221. auth_header = request.headers.get("Authorization")
  222. if auth_header is not None:
  223. api_key = auth_header.split(None, 1)[-1]
  224. if api_key and api_key != "Bearer":
  225. config.api_key = api_key
  226. try:
  227. response = await self.client.images.generate(
  228. prompt=config.prompt,
  229. model=config.model,
  230. provider=AppConfig.image_provider if config.provider is None else config.provider,
  231. **filter_none(
  232. response_format = config.response_format,
  233. api_key = config.api_key,
  234. proxy = config.proxy
  235. )
  236. )
  237. for image in response.data:
  238. if hasattr(image, "url") and image.url.startswith("/"):
  239. image.url = f"{request.base_url}{image.url.lstrip('/')}"
  240. return JSONResponse(response.to_json())
  241. except Exception as e:
  242. logger.exception(e)
  243. return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
  244. @self.app.post("/v1/completions")
  245. async def completions():
  246. return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
  247. @self.app.get("/images/{filename}")
  248. async def get_image(filename):
  249. target = os.path.join(images_dir, filename)
  250. if not os.path.isfile(target):
  251. return Response(status_code=404)
  252. with open(target, "rb") as f:
  253. content_type = is_accepted_format(f.read(12))
  254. return FileResponse(target, media_type=content_type)
  255. def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
  256. last_provider = {} if not image else g4f.get_last_provider(True)
  257. provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider
  258. model = AppConfig.model if config.model is None else config.model
  259. return json.dumps({
  260. "error": {"message": f"{e.__class__.__name__}: {e}"},
  261. "model": last_provider.get("model") if model is None else model,
  262. **filter_none(
  263. provider=last_provider.get("name") if provider is None else provider
  264. )
  265. })
  266. def run_api(
  267. host: str = '0.0.0.0',
  268. port: int = 1337,
  269. bind: str = None,
  270. debug: bool = False,
  271. workers: int = None,
  272. use_colors: bool = None,
  273. reload: bool = False
  274. ) -> None:
  275. print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
  276. if use_colors is None:
  277. use_colors = debug
  278. if bind is not None:
  279. host, port = bind.split(":")
  280. uvicorn.run(
  281. f"g4f.api:create_app{'_debug' if debug else ''}",
  282. host=host,
  283. port=int(port),
  284. workers=workers,
  285. use_colors=use_colors,
  286. factory=True,
  287. reload=reload
  288. )