common.lib.llm
import json
from typing import List, Optional, Union

from pydantic import SecretStr
from langchain_core.messages import BaseMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI


class LLMAdapter:
    def __init__(
        self,
        provider: str,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.0,
    ):
        """
        provider: 'openai', 'google', 'anthropic', 'mistral', 'ollama', 'vllm', 'lmstudio'
        model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small')
        api_key: API key, if the provider requires one (OpenAI, Anthropic, Google, Mistral)
        base_url: for local models or custom endpoints
        """
        self.provider = provider.lower()
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.temperature = temperature
        self.llm: BaseChatModel = self._load_llm()

    def _load_llm(self) -> BaseChatModel:
        if self.provider == "openai":
            kwargs = {}
            if "o3" not in self.model:
                # temperature is not supported by all models (e.g., o3)
                kwargs["temperature"] = self.temperature
            return ChatOpenAI(
                model=self.model,
                api_key=SecretStr(self.api_key),
                base_url=self.base_url or "https://api.openai.com/v1",
                **kwargs,
            )
        elif self.provider == "google":
            return ChatGoogleGenerativeAI(
                model=self.model,
                temperature=self.temperature,
                google_api_key=self.api_key,
            )
        elif self.provider == "anthropic":
            return ChatAnthropic(
                model_name=self.model,
                temperature=self.temperature,
                api_key=SecretStr(self.api_key),
                timeout=100,
                stop=None,
            )
        elif self.provider == "mistral":
            return ChatMistralAI(
                model_name=self.model,
                temperature=self.temperature,
                api_key=SecretStr(self.api_key),
                base_url=self.base_url,  # optional override
            )
        elif self.provider == "ollama":
            ollama_adapter = ChatOllama(
                model=self.model,
                temperature=self.temperature,
                base_url=self.base_url or "http://localhost:11434",
            )
            self.model = ollama_adapter.model
            return ollama_adapter
        elif self.provider in {"vllm", "lmstudio"}:
            # OpenAI-compatible local servers
            if self.provider == "lmstudio" and not self.api_key:
                # LM Studio requires some key to be set, but does not check it
                self.api_key = "lm-studio"
            return ChatOpenAI(
                model=self.model,
                temperature=self.temperature,
                api_key=SecretStr(self.api_key),
                base_url=self.base_url,
            )
        else:
            raise ValueError(f"Unsupported LLM provider: {self.provider}")

    def text_generation(
        self,
        messages: Union[str, List[BaseMessage]],
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Supports string input or a list of LangChain messages.
        """
        if isinstance(messages, str):
            from langchain_core.messages import HumanMessage, SystemMessage

            lc_messages = []
            if system_prompt:
                lc_messages.append(SystemMessage(content=system_prompt))
            lc_messages.append(HumanMessage(content=messages))
        else:
            lc_messages = messages

        return self.llm.invoke(lc_messages).content

    @staticmethod
    def get_model_options(config) -> dict:
        """
        Returns model choice options for UserInput

        :param config: Configuration reader
        """
        models = LLMAdapter.get_models(config)
        options = {model_id: model_values["name"] for model_id, model_values in models.items()}
        return options

    @staticmethod
    def get_models(config) -> dict:
        """
        Returns a dict with LLM models supported by 4CAT, either through an API
        or as a local option. Make sure to keep up to date!

        :param config: Configuration reader

        :return: A dict with model IDs as keys and model details as values
        """
        with (
            config.get("PATH_ROOT")
            .joinpath("common/assets/llms.json")
            .open() as available_models
        ):
            return json.load(available_models)
class LLMAdapter:
LLMAdapter(provider: str, model: str, api_key: Optional[str] = None, base_url: Optional[str] = None, temperature: float = 0.0)
Parameters
- provider: 'openai', 'google', 'anthropic', 'mistral', 'ollama', 'vllm', or 'lmstudio'
- model: model name (e.g., 'gpt-4o-mini', 'claude-3-opus', 'mistral-small')
- api_key: API key, if the provider requires one (OpenAI, Anthropic, Google, Mistral)
- base_url: base URL, for local models or custom endpoints
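A minimal usage sketch (the model names and API key below are placeholders, not recommendations):

    from common.lib.llm import LLMAdapter

    # Hosted provider: an API key is required
    adapter = LLMAdapter(
        provider="openai",
        model="gpt-4o-mini",
        api_key="sk-...",  # placeholder
    )

    # Local Ollama server: no key needed; base_url defaults to http://localhost:11434
    local_adapter = LLMAdapter(provider="ollama", model="llama3")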
def text_generation(self, messages: Union[str, List[BaseMessage]], system_prompt: Optional[str] = None) -> str:
Generates a response from the model. Accepts either a plain string (optionally preceded by a system prompt) or a list of LangChain messages, and returns the model's reply as a string.
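For example, assuming an adapter constructed as above (the prompt text is illustrative):

    # Plain string input with an optional system prompt
    response = adapter.text_generation(
        "Summarise this thread in one sentence.",
        system_prompt="You are a concise annotator.",
    )

    # Equivalent call with an explicit LangChain message list
    from langchain_core.messages import HumanMessage, SystemMessage

    response = adapter.text_generation([
        SystemMessage(content="You are a concise annotator."),
        HumanMessage(content="Summarise this thread in one sentence."),
    ])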
@staticmethod
def get_model_options(config) -> dict:
Returns model choice options for UserInput, mapping each model ID to its display name.
Parameters
- config: Configuration reader
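The result maps model IDs to display names; the values below are illustrative, as the real entries come from common/assets/llms.json:

    options = LLMAdapter.get_model_options(config)
    # e.g. {"gpt-4o-mini": "GPT-4o mini", "mistral-small": "Mistral Small"}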
@staticmethod
def get_models(config) -> dict:
Returns a dict with the LLM models supported by 4CAT, available either through an API or as a local option. The list is read from common/assets/llms.json; make sure to keep that file up to date!
Parameters
- config: Configuration reader
Returns
- dict: A dict with model IDs as keys and model details as values
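The exact schema of llms.json is not documented here; the only field the code above relies on is each entry's "name" key (used by get_model_options). A hypothetical entry could look like:

    models = LLMAdapter.get_models(config)
    # Hypothetical shape -- only the "name" field is confirmed by the code:
    # {"gpt-4o-mini": {"name": "GPT-4o mini"}}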