parent 7db6930bd4
commit 5f56023dc3
@@ -1,8 +0,0 @@
version: 0.0.1
patterns:
  - name: github.com/getgrit/js#*
  - name: github.com/getgrit/python#*
  - name: github.com/getgrit/json#*
  - name: github.com/getgrit/hcl#*
  - name: github.com/getgrit/python#openai
level: info
@@ -0,0 +1,217 @@
# !pip install accelerate
# !pip install torch
# !pip install transformers
# !pip install bitsandbytes

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextStreamer,
)
from typing import Callable, Dict, List


class LlamaFunctionCaller:
    """
    A class to manage and execute Llama functions.

    Attributes:
    -----------
    model: transformers.AutoModelForCausalLM
        The loaded Llama model.
    tokenizer: transformers.AutoTokenizer
        The tokenizer for the Llama model.
    functions: Dict[str, Callable]
        A dictionary of functions available for execution.

    Methods:
    --------
    __init__(self, model_id: str, cache_dir: str, runtime: str)
        Initializes the LlamaFunctionCaller with the specified model.
    add_func(self, name: str, function: Callable, description: str, arguments: List[Dict])
        Adds a new function to the LlamaFunctionCaller.
    call_function(self, name: str, **kwargs)
        Calls the specified function with given arguments.
    __call__(self, task: str, **kwargs)
        Streams a user prompt to the model and returns the generated output.


    Example:

    # Example usage
    model_id = "Your-Model-ID"
    cache_dir = "Your-Cache-Directory"
    runtime = "cuda"  # or 'cpu'

    llama_caller = LlamaFunctionCaller(model_id, cache_dir, runtime)


    # Add a custom function
    def get_weather(location: str, format: str) -> str:
        # This is a placeholder for the actual implementation
        return f"Weather at {location} in {format} format."


    llama_caller.add_func(
        name="get_weather",
        function=get_weather,
        description="Get the weather at a location",
        arguments=[
            {
                "name": "location",
                "type": "string",
                "description": "Location for the weather",
            },
            {
                "name": "format",
                "type": "string",
                "description": "Format of the weather data",
            },
        ],
    )

    # Call the function
    result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
    print(result)

    # Stream a user prompt
    llama_caller("Tell me about the tallest mountain in the world.")

    """

    def __init__(
        self,
        model_id: str = "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        cache_dir: str = "llama_cache",
        runtime: str = "auto",
        max_tokens: int = 500,
        streaming: bool = False,
        *args,
        **kwargs,
    ):
        self.model_id = model_id
        self.cache_dir = cache_dir
        self.runtime = runtime
        self.max_tokens = max_tokens
        self.streaming = streaming

        # Load the model and tokenizer
        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, cache_dir=cache_dir, use_fast=True
        )
        self.functions = {}

    def _load_model(self):
        # Configuration for loading the model in 4-bit precision
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map=self.runtime,
            trust_remote_code=True,
            cache_dir=self.cache_dir,
        )

    def add_func(
        self, name: str, function: Callable, description: str, arguments: List[Dict]
    ):
        """
        Adds a new function to the LlamaFunctionCaller.

        Args:
            name (str): The name of the function.
            function (Callable): The function to execute.
            description (str): Description of the function.
            arguments (List[Dict]): List of argument specifications.
        """
        self.functions[name] = {
            "function": function,
            "description": description,
            "arguments": arguments,
        }

    def call_function(self, name: str, **kwargs):
        """
        Calls the specified function with given arguments.

        Args:
            name (str): The name of the function to call.
            **kwargs: Keyword arguments for the function call.

        Returns:
            The result of the function call.
        """
        if name not in self.functions:
            raise ValueError(f"Function {name} not found.")

        func_info = self.functions[name]
        return func_info["function"](**kwargs)

    def __call__(self, task: str, **kwargs):
        """
        Streams a user prompt to the model and returns the generated output.

        Args:
            task (str): The user prompt to stream.
        """
        # Format the prompt
        prompt = f"{task}\n\n"

        # Encode and move the inputs to the device the model was loaded on
        # (self.runtime may be "auto", which is not a valid torch device)
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        streamer = TextStreamer(self.tokenizer)

        if self.streaming:
            out = self.model.generate(
                **inputs, streamer=streamer, max_new_tokens=self.max_tokens, **kwargs
            )

            return out
        else:
            out = self.model.generate(**inputs, max_length=self.max_tokens, **kwargs)
            # return self.tokenizer.decode(out[0], skip_special_tokens=True)
            return out


llama_caller = LlamaFunctionCaller()


# Add a custom function
def get_weather(location: str, format: str) -> str:
    # This is a placeholder for the actual implementation
    return f"Weather at {location} in {format} format."


llama_caller.add_func(
    name="get_weather",
    function=get_weather,
    description="Get the weather at a location",
    arguments=[
        {
            "name": "location",
            "type": "string",
            "description": "Location for the weather",
        },
        {
            "name": "format",
            "type": "string",
            "description": "Format of the weather data",
        },
    ],
)

# Call the function
result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
print(result)

# Stream a user prompt
llama_caller("Tell me about the tallest mountain in the world.")
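# A hedged sketch (not part of the original): `__call__` returns the raw token
# ids produced by `generate`, so decoding them back to text (non-streaming
# path) would look like this:
# out = llama_caller("Tell me about the tallest mountain in the world.")
# print(llama_caller.tokenizer.decode(out[0], skip_special_tokens=True))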
@@ -0,0 +1 @@
""""""
@@ -0,0 +1,246 @@
from typing import Any, Dict, List, Optional, Union

import openai
import requests
from pydantic import BaseModel, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from termcolor import colored


class FunctionSpecification(BaseModel):
    """
    Defines the specification for a function, including its parameters and metadata.

    Attributes:
    -----------
    name: str
        The name of the function.
    description: str
        A brief description of what the function does.
    parameters: Dict[str, Any]
        The parameters required by the function, with their details.
    required: Optional[List[str]]
        List of required parameter names.

    Methods:
    --------
    validate_params(params: Dict[str, Any]) -> None:
        Validates the parameters against the function's specification.


    Example:

    # Example Usage
    def get_current_weather(location: str, format: str) -> str:
        '''
        Example function to get current weather.

        Args:
            location (str): The city and state, e.g. San Francisco, CA.
            format (str): The temperature unit, e.g. celsius or fahrenheit.

        Returns:
            str: Weather information.
        '''
        # Implementation goes here
        return "Sunny, 23°C"


    weather_function_spec = FunctionSpecification(
        name="get_current_weather",
        description="Get the current weather",
        parameters={
            "location": {"type": "string", "description": "The city and state"},
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit",
            },
        },
        required=["location", "format"],
    )

    # Validating parameters for the function
    params = {"location": "San Francisco, CA", "format": "celsius"}
    weather_function_spec.validate_params(params)

    # Calling the function
    print(get_current_weather(**params))
    """

    name: str
    description: str
    parameters: Dict[str, Any]
    required: Optional[List[str]] = None

    @validator("parameters")
    def check_parameters(cls, params):
        if not isinstance(params, dict):
            raise ValueError("Parameters must be a dictionary.")
        return params

    def validate_params(self, params: Dict[str, Any]) -> None:
        """
        Validates the parameters against the function's specification.

        Args:
            params (Dict[str, Any]): The parameters to validate.

        Raises:
            ValueError: If any required parameter is missing or if any parameter is invalid.
        """
        for key, value in params.items():
            if key in self.parameters:
                self.parameters[key]
                # Perform specific validation based on param_spec
                # This can include type checking, range validation, etc.
            else:
                raise ValueError(f"Unexpected parameter: {key}")

        for req_param in self.required or []:
            if req_param not in params:
                raise ValueError(f"Missing required parameter: {req_param}")
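# A hedged sketch (not part of the original): `validate_params` raises
# ValueError for unexpected or missing parameters, e.g. with the
# `weather_function_spec` from the docstring example above:
#
# try:
#     weather_function_spec.validate_params({"location": "Paris"})
# except ValueError as err:
#     print(err)  # -> Missing required parameter: format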
class OpenAIFunctionCaller:
    def __init__(
        self,
        openai_api_key: str,
        model: str = "text-davinci-003",
        max_tokens: int = 3000,
        temperature: float = 0.5,
        top_p: float = 1.0,
        n: int = 1,
        stream: bool = False,
        stop: Optional[str] = None,
        echo: bool = False,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        logprobs: Optional[int] = None,
        best_of: int = 1,
        logit_bias: Dict[str, float] = None,
        user: str = None,
        messages: List[Dict] = None,
        timeout_sec: Union[float, None] = None,
    ):
        self.openai_api_key = openai_api_key
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.stream = stream
        self.stop = stop
        self.echo = echo
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.logprobs = logprobs
        self.best_of = best_of
        self.logit_bias = logit_bias
        self.user = user
        self.messages = messages if messages is not None else []
        self.timeout_sec = timeout_sec

    def add_message(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})

    @retry(
        wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3)
    )
    def chat_completion_request(
        self,
        messages,
        tools=None,
        tool_choice=None,
    ):
        headers = {
            "Content-Type": "application/json",
            # Use the key passed to the constructor rather than the global openai.api_key
            "Authorization": "Bearer " + self.openai_api_key,
        }
        json_data = {"model": self.model, "messages": messages}
        if tools is not None:
            json_data.update({"tools": tools})
        if tool_choice is not None:
            json_data.update({"tool_choice": tool_choice})
        try:
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=json_data,
            )
            return response
        except Exception as e:
            print("Unable to generate ChatCompletion response")
            print(f"Exception: {e}")
            return e

    def pretty_print_conversation(self, messages):
        role_to_color = {
            "system": "red",
            "user": "green",
            "assistant": "blue",
            "tool": "magenta",
        }

        for message in messages:
            if message["role"] == "system":
                print(
                    colored(
                        f"system: {message['content']}\n",
                        role_to_color[message["role"]],
                    )
                )
            elif message["role"] == "user":
                print(
                    colored(
                        f"user: {message['content']}\n", role_to_color[message["role"]]
                    )
                )
            elif message["role"] == "assistant" and message.get("function_call"):
                print(
                    colored(
                        f"assistant: {message['function_call']}\n",
                        role_to_color[message["role"]],
                    )
                )
            elif message["role"] == "assistant" and not message.get("function_call"):
                print(
                    colored(
                        f"assistant: {message['content']}\n",
                        role_to_color[message["role"]],
                    )
                )
            elif message["role"] == "tool":
                print(
                    colored(
                        f"function ({message['name']}): {message['content']}\n",
                        role_to_color[message["role"]],
                    )
                )

    def call(self, prompt: str) -> Dict:
        # Note: the legacy Completions endpoint does not accept chat `messages`;
        # chat-style requests go through `chat_completion_request` above.
        response = openai.Completion.create(
            engine=self.model,
            prompt=prompt,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            n=self.n,
            stream=self.stream,
            stop=self.stop,
            echo=self.echo,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            logprobs=self.logprobs,
            best_of=self.best_of,
            logit_bias=self.logit_bias,
            user=self.user,
            request_timeout=self.timeout_sec,
        )
        return response

    def run(self, prompt: str) -> str:
        response = self.call(prompt)
        return response["choices"][0]["text"].strip()
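# Example usage (a hedged sketch, not part of the original module; the API key
# and model name below are placeholders):
#
# caller = OpenAIFunctionCaller(openai_api_key="sk-...", model="gpt-3.5-turbo")
# caller.add_message("user", "What is the weather like in Paris?")
# response = caller.chat_completion_request(caller.messages)
# print(response.json())
# caller.pretty_print_conversation(caller.messages)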
@@ -0,0 +1,253 @@
import concurrent.futures
import os
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import List

import backoff
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from pydantic import validator
from termcolor import colored
from cachetools import TTLCache


@dataclass
class SSD1B:
    """
    SSD1B model class

    Attributes:
    -----------
    image_url: str
        The image url generated by the SSD1B API

    Methods:
    --------
    __call__(self, task: str) -> SSD1B:
        Makes a call to the SSD1B pipeline and returns the path of the generated image

    Example:
    --------
    model = SSD1B()
    task = "A painting of a dog"
    neg_prompt = "ugly, blurry, poor quality"
    image_url = model(task, neg_prompt)
    print(image_url)
    """

    model: str = "dall-e-3"
    img: str = None
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
    model_name: str = "segment/SSD-1B"
    n: int = 1
    save_path: str = "images"
    max_time_seconds: int = 60
    save_folder: str = "images"
    image_format: str = "png"
    device: str = "cuda"
    dashboard: bool = False
    image_url: str = None
    cache = TTLCache(maxsize=100, ttl=3600)
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "segmind/SSD-1B",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)

    def __post_init__(self):
        """Post init method"""

        if self.img is not None:
            self.img = self.convert_to_bytesio(self.img)

        os.makedirs(self.save_path, exist_ok=True)

    class Config:
        """Config class for the SSD1B model"""

        arbitrary_types_allowed = True

    @validator("max_retries", "max_time_seconds")
    def must_be_positive(cls, value):
        if value <= 0:
            raise ValueError("Must be positive")
        return value

    def read_img(self, img: str):
        """Read the image using PIL"""
        img = Image.open(img)
        return img

    def set_width_height(self, img: str, width: int, height: int):
        """Set the width and height of the image"""
        img = self.read_img(img)
        img = img.resize((width, height))
        return img

    def convert_to_bytesio(self, img: str, format: str = "PNG"):
        """Convert the image to a BytesIO object"""
        byte_stream = BytesIO()
        img.save(byte_stream, format=format)
        byte_array = byte_stream.getvalue()
        return byte_array

    @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
    def __call__(self, task: str, neg_prompt: str = None):
        """
        Text to image conversion using the SSD1B pipeline

        Parameters:
        -----------
        task: str
            The task to be converted to an image
        neg_prompt: str
            An optional negative prompt (optional so batch helpers can call with a task only)

        Returns:
        --------
        str:
            The path of the image generated and saved by the pipeline

        Example:
        --------
        >>> ssd1b = SSD1B()
        >>> task = "A painting of a dog"
        >>> image_path = ssd1b(task)
        >>> print(image_path)
        """
        if self.dashboard:
            self.print_dashboard()
        if task in self.cache:
            return self.cache[task]
        try:
            img = self.pipe(prompt=task, negative_prompt=neg_prompt).images[0]

            # Generate a unique filename for the image
            img_name = f"{uuid.uuid4()}.{self.image_format}"
            img_path = os.path.join(self.save_path, img_name)

            # Save the image
            img.save(img_path, self.image_format)
            self.cache[task] = img_path

            return img_path

        except Exception as error:
            # Handle exceptions and print the error details
            print(
                colored(
                    (
                        f"Error running SSD1B: {error} try optimizing your api key and"
                        " or try again"
                    ),
                    "red",
                )
            )
            raise error

    def _generate_image_name(self, task: str):
        """Generate a sanitized file name based on the task"""
        sanitized_task = "".join(
            char for char in task if char.isalnum() or char in " _ -"
        ).rstrip()
        return f"{sanitized_task}.{self.image_format}"

    def _download_image(self, img: Image, filename: str):
        """
        Save the PIL Image object to a file.
        """
        full_path = os.path.join(self.save_path, filename)
        img.save(full_path, self.image_format)

    def print_dashboard(self):
        """Print the SSD1B dashboard"""
        print(
            colored(
                (
                    f"""SSD1B Dashboard:
                    --------------------

                    Model: {self.model}
                    Image: {self.img}
                    Size: {self.size}
                    Max Retries: {self.max_retries}
                    Quality: {self.quality}
                    N: {self.n}
                    Save Path: {self.save_path}
                    Time Seconds: {self.max_time_seconds}
                    Save Folder: {self.save_folder}
                    Image Format: {self.image_format}
                    --------------------


                    """
                ),
                "green",
            )
        )

    def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
        """
        Process a batch of tasks concurrently

        Args:
            tasks (List[str]): A list of tasks to be processed
            max_workers (int): The maximum number of workers to use for the concurrent processing

        Returns:
        --------
        results (List[str]): A list of image paths generated by the SSD1B pipeline

        Example:
        --------
        >>> model = SSD1B()
        >>> tasks = ["A painting of a dog", "A painting of a cat"]
        >>> results = model.process_batch_concurrently(tasks)
        >>> print(results)

        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_task = {executor.submit(self, task): task for task in tasks}
            results = []
            for future in concurrent.futures.as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    img = future.result()
                    results.append(img)

                    print(f"Task {task} completed: {img}")
                except Exception as error:
                    print(
                        colored(
                            (
                                f"Error running SSD1B: {error} try optimizing your api key and"
                                " or try again"
                            ),
                            "red",
                        )
                    )
                    print(colored(f"Error running SSD1B: {error.http_status}", "red"))
                    print(colored(f"Error running SSD1B: {error.error}", "red"))
                    raise error
        return results

    def _generate_uuid(self):
        """Generate a uuid"""
        return str(uuid.uuid4())

    def __repr__(self):
        """Repr method for the SSD1B class"""
        return f"SSD1B(image_url={self.image_url})"

    def __str__(self):
        """Str method for the SSD1B class"""
        return f"SSD1B(image_url={self.image_url})"

    @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
    def rate_limited_call(self, task: str):
        """Rate limited call to the SSD1B pipeline"""
        return self.__call__(task)
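# Example usage (a hedged sketch, not part of the original module; assumes a
# CUDA device and that the segmind/SSD-1B weights can be downloaded):
#
# model = SSD1B(dashboard=True)
# image_path = model("A painting of a dog", "ugly, blurry, poor quality")
# print(image_path)
# batch_paths = model.process_batch_concurrently(
#     ["A painting of a cat", "A watercolor of a mountain lake"], max_workers=2
# )
# print(batch_paths)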
@@ -0,0 +1,97 @@
from transformers import AutoModelForCausalLM, AutoTokenizer


class Yi34B200k:
    """
    A class for easy interaction with the Yi-34B-200K model.

    Attributes:
    -----------
    model_id: str
        The model id of the model to be used.
    device_map: str
        The device to be used for inference.
    torch_dtype: str
        The torch dtype to be used for inference.
    max_length: int
        The maximum length of the generated text.
    repetition_penalty: float
        The repetition penalty to be used for inference.
    no_repeat_ngram_size: int
        The no repeat ngram size to be used for inference.
    temperature: float
        The temperature to be used for inference.

    Methods:
    --------
    __call__(self, task: str) -> str:
        Generates text based on the given prompt.


    """

    def __init__(
        self,
        model_id: str = "01-ai/Yi-34B-200K",
        device_map: str = "auto",
        torch_dtype: str = "auto",
        max_length: int = 512,
        repetition_penalty: float = 1.3,
        no_repeat_ngram_size: int = 5,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.8,
    ):
        super().__init__()
        self.model_id = model_id
        self.device_map = device_map
        self.torch_dtype = torch_dtype
        self.max_length = max_length
        self.repetition_penalty = repetition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

    def __call__(self, task: str):
        """
        Generates text based on the given prompt.

        Args:
            task (str): The input text prompt.

        Returns:
            str: The generated text.
        """
        inputs = self.tokenizer(task, return_tensors="pt")
        outputs = self.model.generate(
            inputs.input_ids.cuda(),
            max_length=self.max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


# # Example usage
# yi34b = Yi34B200k()
# prompt = "There's a place where time stands still. A place of breathtaking wonder, but also"
# generated_text = yi34b(prompt)
# print(generated_text)
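# # A hedged sketch (not from the original): sampling behaviour can be tuned
# # through the constructor; the values below are illustrative only.
# # yi34b_creative = Yi34B200k(temperature=1.0, top_p=0.95, max_length=1024)
# # print(yi34b_creative("Write a short poem about the sea."))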