You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/jamba_swarm/jamba_llm.py

57 lines
1.5 KiB

import os
from ai21 import AI21Client
from ai21.models.chat import ChatMessage
from dotenv import load_dotenv
from swarms import BaseLLM
load_dotenv()
class Jamba(BaseLLM):
    """Thin wrapper that sends a single-turn prompt to an AI21 Jamba chat model.

    The instance stores the sampling settings and an ``AI21Client``; ``run``
    issues one chat-completion request per call and returns the text of the
    first choice.
    """

    def __init__(
        self,
        api_key: str = os.getenv("AI21_API_KEY"),
        temperature: float = 0.8,
        max_tokens: int = 200,
    ):
        """
        Initializes the Jamba class with the provided API key.

        Args:
            api_key (str): The API key for the AI21Client. When omitted, the
                ``AI21_API_KEY`` environment variable is used.
            temperature (float): Sampling temperature sent with every request.
            max_tokens (int): Maximum number of tokens requested per response.

        Raises:
            ValueError: If no API key was passed and ``AI21_API_KEY`` is not
                set in the environment.
        """
        # The signature default is evaluated once, when the class body is
        # executed. Look the variable up again so a key exported after import
        # is still picked up.
        if api_key is None:
            api_key = os.getenv("AI21_API_KEY")
        if not api_key:
            # Without this guard the os.environ assignment below fails with
            # an opaque "str expected, not NoneType" TypeError.
            raise ValueError(
                "AI21 API key is missing: pass api_key or set the "
                "AI21_API_KEY environment variable."
            )

        # Kept for backward compatibility: other code may read the key from
        # the process environment.
        os.environ["AI21_API_KEY"] = api_key

        # NOTE(review): BaseLLM.__init__ is not invoked here, matching the
        # original behavior -- confirm the base class needs no setup.
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Hand the key to the client explicitly instead of relying on the SDK
        # noticing the environment change made above.
        self.client = AI21Client(api_key=api_key)

    def run(self, prompt: str, *args, **kwargs) -> str:
        """
        Generates a response for the given prompt using the AI21 model.

        Args:
            prompt (str): The prompt for generating the response.
            *args: Extra positional arguments forwarded to
                ``client.chat.completions.create``.
            **kwargs: Extra keyword arguments forwarded to
                ``client.chat.completions.create``. Keys that match a default
                below (for example ``model``, ``temperature`` or
                ``max_tokens``) override it for this call only.

        Returns:
            str: The generated response.

        Raises:
            Exception: If there is an issue with the API request. The error is
                printed and then re-raised unchanged.
        """
        # Build the request as a dict so caller-supplied keywords replace the
        # defaults instead of colliding with them ("multiple values for
        # keyword argument").
        request = {
            "model": "jamba-instruct-preview",  # Latest model
            "messages": [ChatMessage(role="user", content=prompt)],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        request.update(kwargs)

        try:
            response = self.client.chat.completions.create(*args, **request)
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error: {e}")
            raise