import os
from typing import Optional

from ai21 import AI21Client
from ai21.models.chat import ChatMessage
from dotenv import load_dotenv
from swarms import BaseLLM

# Populate os.environ from a local .env file (if any) at import time.
load_dotenv()


class Jamba(BaseLLM):
    """``BaseLLM`` wrapper that sends single-turn prompts to AI21 chat completions."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 200,
        model_name: str = "jamba-instruct-preview",
    ):
        """
        Initializes the Jamba class with the provided API key.

        Args:
            api_key (Optional[str]): The API key for the AI21Client. When
                omitted, it is read from the ``AI21_API_KEY`` environment
                variable at construction time.
            temperature (float): Sampling temperature sent with every request.
            max_tokens (int): Maximum number of tokens requested per completion.
            model_name (str): AI21 chat model used for completions.

        Raises:
            ValueError: If no API key is passed and ``AI21_API_KEY`` is unset
                or empty.
        """
        # Resolve the key per instance instead of in the default argument:
        # a default of os.getenv(...) is evaluated only once, when the class
        # is defined, and is None if the variable was unset at that moment.
        if api_key is None:
            api_key = os.getenv("AI21_API_KEY")
        if not api_key:
            raise ValueError(
                "An AI21 API key is required: pass api_key or set AI21_API_KEY."
            )

        # Exported so any code reading AI21_API_KEY from the environment
        # sees the same key this instance uses.
        os.environ["AI21_API_KEY"] = api_key
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model_name = model_name
        # Pass the key explicitly rather than relying on the env lookup.
        self.client = AI21Client(api_key=api_key)

    def run(self, prompt: str, *args, **kwargs) -> str:
        """
        Generates a response for the given prompt using the AI21 model.

        Args:
            prompt (str): The prompt for generating the response; it is sent
                as a single ``user`` message.
            *args: Extra positional arguments forwarded to
                ``client.chat.completions.create``.
            **kwargs: Extra keyword arguments forwarded to
                ``client.chat.completions.create``; they must not repeat
                ``model``, ``messages``, ``temperature`` or ``max_tokens``.

        Returns:
            str: The content of the first choice's message.

        Raises:
            Exception: Any error from the API request or from reading the
                response is printed and then re-raised unchanged.
        """
        try:
            response = self.client.chat.completions.create(
                *args,
                model=self.model_name,
                messages=[ChatMessage(role="user", content=prompt)],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                **kwargs,
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error: {e}")
            raise