pydantic_ai.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. from typing import Optional
  3. from functools import partial
  4. from dataclasses import dataclass, field
  5. from pydantic_ai.models import Model, KnownModelName, infer_model
  6. from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
  7. from ..client import AsyncClient
  8. @dataclass(init=False)
  9. class AIModel(OpenAIModel):
  10. """A model that uses the G4F API."""
  11. client: AsyncClient = field(repr=False)
  12. system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
  13. _model_name: str = field(repr=False)
  14. _provider: str = field(repr=False)
  15. _system: Optional[str] = field(repr=False)
  16. def __init__(
  17. self,
  18. model_name: str,
  19. provider: str | None = None,
  20. *,
  21. system_prompt_role: OpenAISystemPromptRole | None = None,
  22. system: str | None = 'openai',
  23. **kwargs
  24. ):
  25. """Initialize an AI model.
  26. Args:
  27. model_name: The name of the AI model to use. List of model names available
  28. [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
  29. (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
  30. system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
  31. In the future, this may be inferred from the model name.
  32. system: The model provider used, defaults to `openai`. This is for observability purposes, you must
  33. customize the `base_url` and `api_key` to use a different provider.
  34. """
  35. self._model_name = model_name
  36. self._provider = provider
  37. self.client = AsyncClient(provider=provider, **kwargs)
  38. self.system_prompt_role = system_prompt_role
  39. self._system = system
  40. def name(self) -> str:
  41. if self._provider:
  42. return f'g4f:{self._provider}:{self._model_name}'
  43. return f'g4f:{self._model_name}'
  44. def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
  45. if isinstance(model, Model):
  46. return model
  47. if model.startswith("g4f:"):
  48. model = model[4:]
  49. if ":" in model:
  50. provider, model = model.split(":", 1)
  51. return AIModel(model, provider=provider, api_key=api_key)
  52. return AIModel(model)
  53. return infer_model(model)
  54. def apply_patch(api_key: str | None = None):
  55. import pydantic_ai.models
  56. import pydantic_ai.models.openai
  57. pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
  58. pydantic_ai.models.AIModel = AIModel
  59. pydantic_ai.models.openai.NOT_GIVEN = None