Edit on GitHub

common.lib.llm

  1import json
  2
  3from typing import List, Optional, Union
  4from pydantic import SecretStr
  5from langchain_core.messages import BaseMessage
  6from langchain_core.language_models.chat_models import BaseChatModel
  7from langchain_anthropic import ChatAnthropic
  8from langchain_google_genai import ChatGoogleGenerativeAI
  9from langchain_ollama import ChatOllama
 10from langchain_openai import ChatOpenAI
 11from langchain_mistralai import ChatMistralAI
 12
 13class LLMAdapter:
 14    def __init__(
 15        self,
 16        provider: str,
 17        model: str,
 18        api_key: Optional[str] = None,
 19        base_url: Optional[str] = None,
 20        temperature: float = 0.0,
 21    ):
 22        """
 23        provider: 'openai', 'google', 'mistral', 'ollama', 'vllm', 'lmstudio', 'mistral'
 24        model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small', etc.)
 25        api_key: API key if required (OpenAI, Claude, Google, Mistral)
 26        base_url: for local models or Mistral custom endpoints
 27        """
 28        self.provider = provider.lower()
 29        self.model = model
 30        self.api_key = api_key
 31        self.base_url = base_url
 32        self.temperature = temperature
 33        self.llm: BaseChatModel = self._load_llm()
 34
 35
 36    def _load_llm(self) -> BaseChatModel:
 37        if self.provider == "openai":
 38            kwargs = {}
 39            if "o3" not in self.model:
 40                kwargs["temperature"] = self.temperature # temperature not supported for all models
 41            return ChatOpenAI(
 42                model=self.model,
 43                api_key=SecretStr(self.api_key),
 44                base_url=self.base_url or "https://api.openai.com/v1",
 45                **kwargs
 46            )
 47        elif self.provider == "google":
 48            return ChatGoogleGenerativeAI(
 49                model=self.model,
 50                temperature=self.temperature,
 51                google_api_key=self.api_key
 52            )
 53        elif self.provider == "anthropic":
 54            return ChatAnthropic(
 55                model_name=self.model,
 56                temperature=self.temperature,
 57                api_key=SecretStr(self.api_key),
 58                timeout=100,
 59                stop=None
 60            )
 61        elif self.provider == "mistral":
 62            return ChatMistralAI(
 63                model_name=self.model,
 64                temperature=self.temperature,
 65                api_key=SecretStr(self.api_key),
 66                base_url=self.base_url  # Optional override
 67            )
 68        elif self.provider == "ollama":
 69            ollama_adapter = ChatOllama(
 70                model=self.model,
 71                temperature=self.temperature,
 72                base_url=self.base_url or "http://localhost:11434"
 73            )
 74            self.model = ollama_adapter.model
 75            return ollama_adapter
 76        elif self.provider in {"vllm", "lmstudio"}:
 77            # OpenAI-compatible local servers
 78            if self.provider == "lmstudio" and not self.api_key:
 79                self.api_key = "lm-studio"
 80            return ChatOpenAI(
 81                model=self.model,
 82                temperature=self.temperature,
 83                api_key=SecretStr(self.api_key),
 84                base_url=self.base_url
 85            )
 86        else:
 87            raise ValueError(f"Unsupported LLM provider: {self.provider}")
 88
 89    def text_generation(
 90        self,
 91        messages: Union[str, List[BaseMessage]],
 92        system_prompt: Optional[str] = None,
 93    ) -> str:
 94        """
 95        Supports string input or LangChain message list.
 96        """
 97        if isinstance(messages, str):
 98            from langchain_core.messages import HumanMessage, SystemMessage
 99            lc_messages = []
100            if system_prompt:
101                lc_messages.append(SystemMessage(content=system_prompt))
102            lc_messages.append(HumanMessage(content=messages))
103        else:
104            lc_messages = messages
105
106        try:
107            response = self.llm.invoke(lc_messages).content
108        except Exception as e:
109            raise e
110
111        return response
112
113    @staticmethod
114    def get_model_options(config) -> dict:
115        """
116        Returns model choice options for UserInput
117
118        :param config:  Configuration reader
119        """
120
121        models = LLMAdapter.get_models(config)
122        options = {model_id: model_values["name"] for model_id, model_values in models.items()}
123        return options
124
125    @staticmethod
126    def get_models(config) -> dict:
127        """
128        Returns a dict with LLM models supported by 4CAT, either through an API or as a local option.
129        Make sure to keep up-to-date!
130
131        :param config:  Configuration reader
132
133        :returns dict, A dict with model IDs as keys and details as values
134        """
135
136        with (
137            config.get("PATH_ROOT")
138            .joinpath("common/assets/llms.json")
139            .open() as available_models
140        ):
141            available_models = json.loads(available_models.read())
142        return available_models
class LLMAdapter:
 14class LLMAdapter:
 15    def __init__(
 16        self,
 17        provider: str,
 18        model: str,
 19        api_key: Optional[str] = None,
 20        base_url: Optional[str] = None,
 21        temperature: float = 0.0,
 22    ):
 23        """
 24        provider: 'openai', 'google', 'mistral', 'ollama', 'vllm', 'lmstudio', 'mistral'
 25        model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small', etc.)
 26        api_key: API key if required (OpenAI, Claude, Google, Mistral)
 27        base_url: for local models or Mistral custom endpoints
 28        """
 29        self.provider = provider.lower()
 30        self.model = model
 31        self.api_key = api_key
 32        self.base_url = base_url
 33        self.temperature = temperature
 34        self.llm: BaseChatModel = self._load_llm()
 35
 36
 37    def _load_llm(self) -> BaseChatModel:
 38        if self.provider == "openai":
 39            kwargs = {}
 40            if "o3" not in self.model:
 41                kwargs["temperature"] = self.temperature # temperature not supported for all models
 42            return ChatOpenAI(
 43                model=self.model,
 44                api_key=SecretStr(self.api_key),
 45                base_url=self.base_url or "https://api.openai.com/v1",
 46                **kwargs
 47            )
 48        elif self.provider == "google":
 49            return ChatGoogleGenerativeAI(
 50                model=self.model,
 51                temperature=self.temperature,
 52                google_api_key=self.api_key
 53            )
 54        elif self.provider == "anthropic":
 55            return ChatAnthropic(
 56                model_name=self.model,
 57                temperature=self.temperature,
 58                api_key=SecretStr(self.api_key),
 59                timeout=100,
 60                stop=None
 61            )
 62        elif self.provider == "mistral":
 63            return ChatMistralAI(
 64                model_name=self.model,
 65                temperature=self.temperature,
 66                api_key=SecretStr(self.api_key),
 67                base_url=self.base_url  # Optional override
 68            )
 69        elif self.provider == "ollama":
 70            ollama_adapter = ChatOllama(
 71                model=self.model,
 72                temperature=self.temperature,
 73                base_url=self.base_url or "http://localhost:11434"
 74            )
 75            self.model = ollama_adapter.model
 76            return ollama_adapter
 77        elif self.provider in {"vllm", "lmstudio"}:
 78            # OpenAI-compatible local servers
 79            if self.provider == "lmstudio" and not self.api_key:
 80                self.api_key = "lm-studio"
 81            return ChatOpenAI(
 82                model=self.model,
 83                temperature=self.temperature,
 84                api_key=SecretStr(self.api_key),
 85                base_url=self.base_url
 86            )
 87        else:
 88            raise ValueError(f"Unsupported LLM provider: {self.provider}")
 89
 90    def text_generation(
 91        self,
 92        messages: Union[str, List[BaseMessage]],
 93        system_prompt: Optional[str] = None,
 94    ) -> str:
 95        """
 96        Supports string input or LangChain message list.
 97        """
 98        if isinstance(messages, str):
 99            from langchain_core.messages import HumanMessage, SystemMessage
100            lc_messages = []
101            if system_prompt:
102                lc_messages.append(SystemMessage(content=system_prompt))
103            lc_messages.append(HumanMessage(content=messages))
104        else:
105            lc_messages = messages
106
107        try:
108            response = self.llm.invoke(lc_messages).content
109        except Exception as e:
110            raise e
111
112        return response
113
114    @staticmethod
115    def get_model_options(config) -> dict:
116        """
117        Returns model choice options for UserInput
118
119        :param config:  Configuration reader
120        """
121
122        models = LLMAdapter.get_models(config)
123        options = {model_id: model_values["name"] for model_id, model_values in models.items()}
124        return options
125
126    @staticmethod
127    def get_models(config) -> dict:
128        """
129        Returns a dict with LLM models supported by 4CAT, either through an API or as a local option.
130        Make sure to keep up-to-date!
131
132        :param config:  Configuration reader
133
134        :returns dict, A dict with model IDs as keys and details as values
135        """
136
137        with (
138            config.get("PATH_ROOT")
139            .joinpath("common/assets/llms.json")
140            .open() as available_models
141        ):
142            available_models = json.loads(available_models.read())
143        return available_models
LLMAdapter( provider: str, model: str, api_key: Optional[str] = None, base_url: Optional[str] = None, temperature: float = 0.0)
15    def __init__(
16        self,
17        provider: str,
18        model: str,
19        api_key: Optional[str] = None,
20        base_url: Optional[str] = None,
21        temperature: float = 0.0,
22    ):
23        """
24        provider: 'openai', 'google', 'mistral', 'ollama', 'vllm', 'lmstudio', 'mistral'
25        model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small', etc.)
26        api_key: API key if required (OpenAI, Claude, Google, Mistral)
27        base_url: for local models or Mistral custom endpoints
28        """
29        self.provider = provider.lower()
30        self.model = model
31        self.api_key = api_key
32        self.base_url = base_url
33        self.temperature = temperature
34        self.llm: BaseChatModel = self._load_llm()

provider: 'openai', 'google', 'anthropic', 'mistral', 'ollama', 'vllm' or 'lmstudio'; model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small', etc.); api_key: API key if required (OpenAI, Anthropic, Google, Mistral); base_url: for local models or custom OpenAI/Mistral-compatible endpoints

provider
model
api_key
base_url
temperature
llm: langchain_core.language_models.chat_models.BaseChatModel
def text_generation( self, messages: Union[str, List[langchain_core.messages.base.BaseMessage]], system_prompt: Optional[str] = None) -> str:
 90    def text_generation(
 91        self,
 92        messages: Union[str, List[BaseMessage]],
 93        system_prompt: Optional[str] = None,
 94    ) -> str:
 95        """
 96        Supports string input or LangChain message list.
 97        """
 98        if isinstance(messages, str):
 99            from langchain_core.messages import HumanMessage, SystemMessage
100            lc_messages = []
101            if system_prompt:
102                lc_messages.append(SystemMessage(content=system_prompt))
103            lc_messages.append(HumanMessage(content=messages))
104        else:
105            lc_messages = messages
106
107        try:
108            response = self.llm.invoke(lc_messages).content
109        except Exception as e:
110            raise e
111
112        return response

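Accepts either a plain string prompt (optionally preceded by a system message built from system_prompt) or a ready-made list of LangChain messages, and returns the content of the model's reply.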

@staticmethod
def get_model_options(config) -> dict:
114    @staticmethod
115    def get_model_options(config) -> dict:
116        """
117        Returns model choice options for UserInput
118
119        :param config:  Configuration reader
120        """
121
122        models = LLMAdapter.get_models(config)
123        options = {model_id: model_values["name"] for model_id, model_values in models.items()}
124        return options

Returns model choice options for UserInput

Parameters
  • config: Configuration reader
@staticmethod
def get_models(config) -> dict:
126    @staticmethod
127    def get_models(config) -> dict:
128        """
129        Returns a dict with LLM models supported by 4CAT, either through an API or as a local option.
130        Make sure to keep up-to-date!
131
132        :param config:  Configuration reader
133
134        :returns dict, A dict with model IDs as keys and details as values
135        """
136
137        with (
138            config.get("PATH_ROOT")
139            .joinpath("common/assets/llms.json")
140            .open() as available_models
141        ):
142            available_models = json.loads(available_models.read())
143        return available_models

Returns a dict with LLM models supported by 4CAT, either through an API or as a local option. Make sure to keep up-to-date!

Parameters
  • config: Configuration reader

Returns: a dict with model IDs as keys and model details as values.