parent
7db6930bd4
commit
5f56023dc3
@ -1,8 +0,0 @@
|
|||||||
version: 0.0.1
|
|
||||||
patterns:
|
|
||||||
- name: github.com/getgrit/js#*
|
|
||||||
- name: github.com/getgrit/python#*
|
|
||||||
- name: github.com/getgrit/json#*
|
|
||||||
- name: github.com/getgrit/hcl#*
|
|
||||||
- name: github.com/getgrit/python#openai
|
|
||||||
level: info
|
|
@ -0,0 +1,217 @@
|
|||||||
|
# !pip install accelerate
|
||||||
|
# !pip install torch
|
||||||
|
# !pip install transformers
|
||||||
|
# !pip install bitsandbytes
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
TextStreamer,
|
||||||
|
)
|
||||||
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaFunctionCaller:
|
||||||
|
"""
|
||||||
|
A class to manage and execute Llama functions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
model: transformers.AutoModelForCausalLM
|
||||||
|
The loaded Llama model.
|
||||||
|
tokenizer: transformers.AutoTokenizer
|
||||||
|
The tokenizer for the Llama model.
|
||||||
|
functions: Dict[str, Callable]
|
||||||
|
A dictionary of functions available for execution.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
--------
|
||||||
|
__init__(self, model_id: str, cache_dir: str, runtime: str)
|
||||||
|
Initializes the LlamaFunctionCaller with the specified model.
|
||||||
|
add_func(self, name: str, function: Callable, description: str, arguments: List[Dict])
|
||||||
|
Adds a new function to the LlamaFunctionCaller.
|
||||||
|
call_function(self, name: str, **kwargs)
|
||||||
|
Calls the specified function with given arguments.
|
||||||
|
stream(self, user_prompt: str)
|
||||||
|
Streams a user prompt to the model and prints the response.
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
model_id = "Your-Model-ID"
|
||||||
|
cache_dir = "Your-Cache-Directory"
|
||||||
|
runtime = "cuda" # or 'cpu'
|
||||||
|
|
||||||
|
llama_caller = LlamaFunctionCaller(model_id, cache_dir, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
# Add a custom function
|
||||||
|
def get_weather(location: str, format: str) -> str:
|
||||||
|
# This is a placeholder for the actual implementation
|
||||||
|
return f"Weather at {location} in {format} format."
|
||||||
|
|
||||||
|
|
||||||
|
llama_caller.add_func(
|
||||||
|
name="get_weather",
|
||||||
|
function=get_weather,
|
||||||
|
description="Get the weather at a location",
|
||||||
|
arguments=[
|
||||||
|
{
|
||||||
|
"name": "location",
|
||||||
|
"type": "string",
|
||||||
|
"description": "Location for the weather",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of the weather data",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# Stream a user prompt
|
||||||
|
llama_caller("Tell me about the tallest mountain in the world.")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str = "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
|
||||||
|
cache_dir: str = "llama_cache",
|
||||||
|
runtime: str = "auto",
|
||||||
|
max_tokens: int = 500,
|
||||||
|
streaming: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.model_id = model_id
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
self.runtime = runtime
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.streaming = streaming
|
||||||
|
|
||||||
|
# Load the model and tokenizer
|
||||||
|
self.model = self._load_model()
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id, cache_dir=cache_dir, use_fast=True
|
||||||
|
)
|
||||||
|
self.functions = {}
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
# Configuration for loading the model
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
return AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_id,
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
device_map=self.runtime,
|
||||||
|
trust_remote_code=True,
|
||||||
|
cache_dir=self.cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_func(
|
||||||
|
self, name: str, function: Callable, description: str, arguments: List[Dict]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds a new function to the LlamaFunctionCaller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the function.
|
||||||
|
function (Callable): The function to execute.
|
||||||
|
description (str): Description of the function.
|
||||||
|
arguments (List[Dict]): List of argument specifications.
|
||||||
|
"""
|
||||||
|
self.functions[name] = {
|
||||||
|
"function": function,
|
||||||
|
"description": description,
|
||||||
|
"arguments": arguments,
|
||||||
|
}
|
||||||
|
|
||||||
|
def call_function(self, name: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Calls the specified function with given arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the function to call.
|
||||||
|
**kwargs: Keyword arguments for the function call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the function call.
|
||||||
|
"""
|
||||||
|
if name not in self.functions:
|
||||||
|
raise ValueError(f"Function {name} not found.")
|
||||||
|
|
||||||
|
func_info = self.functions[name]
|
||||||
|
return func_info["function"](**kwargs)
|
||||||
|
|
||||||
|
def __call__(self, task: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Streams a user prompt to the model and prints the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The user prompt to stream.
|
||||||
|
"""
|
||||||
|
# Format the prompt
|
||||||
|
prompt = f"{task}\n\n"
|
||||||
|
|
||||||
|
# Encode and send to the model
|
||||||
|
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.runtime)
|
||||||
|
|
||||||
|
streamer = TextStreamer(self.tokenizer)
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
out = self.model.generate(
|
||||||
|
**inputs, streamer=streamer, max_new_tokens=self.max_tokens, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
out = self.model.generate(**inputs, max_length=self.max_tokens, **kwargs)
|
||||||
|
# return self.tokenizer.decode(out[0], skip_special_tokens=True)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
llama_caller = LlamaFunctionCaller()
|
||||||
|
|
||||||
|
|
||||||
|
# Add a custom function
|
||||||
|
def get_weather(location: str, format: str) -> str:
|
||||||
|
# This is a placeholder for the actual implementation
|
||||||
|
return f"Weather at {location} in {format} format."
|
||||||
|
|
||||||
|
|
||||||
|
llama_caller.add_func(
|
||||||
|
name="get_weather",
|
||||||
|
function=get_weather,
|
||||||
|
description="Get the weather at a location",
|
||||||
|
arguments=[
|
||||||
|
{
|
||||||
|
"name": "location",
|
||||||
|
"type": "string",
|
||||||
|
"description": "Location for the weather",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of the weather data",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# Stream a user prompt
|
||||||
|
llama_caller("Tell me about the tallest mountain in the world.")
|
@ -0,0 +1 @@
|
|||||||
|
""""""
|
@ -0,0 +1,246 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionSpecification(BaseModel):
|
||||||
|
"""
|
||||||
|
Defines the specification for a function including its parameters and metadata.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
name: str
|
||||||
|
The name of the function.
|
||||||
|
description: str
|
||||||
|
A brief description of what the function does.
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
The parameters required by the function, with their details.
|
||||||
|
required: Optional[List[str]]
|
||||||
|
List of required parameter names.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
--------
|
||||||
|
validate_params(params: Dict[str, Any]) -> None:
|
||||||
|
Validates the parameters against the function's specification.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
# Example Usage
|
||||||
|
def get_current_weather(location: str, format: str) -> str:
|
||||||
|
``'
|
||||||
|
Example function to get current weather.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (str): The city and state, e.g. San Francisco, CA.
|
||||||
|
format (str): The temperature unit, e.g. celsius or fahrenheit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Weather information.
|
||||||
|
'''
|
||||||
|
# Implementation goes here
|
||||||
|
return "Sunny, 23°C"
|
||||||
|
|
||||||
|
|
||||||
|
weather_function_spec = FunctionSpecification(
|
||||||
|
name="get_current_weather",
|
||||||
|
description="Get the current weather",
|
||||||
|
parameters={
|
||||||
|
"location": {"type": "string", "description": "The city and state"},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required=["location", "format"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validating parameters for the function
|
||||||
|
params = {"location": "San Francisco, CA", "format": "celsius"}
|
||||||
|
weather_function_spec.validate_params(params)
|
||||||
|
|
||||||
|
# Calling the function
|
||||||
|
print(get_current_weather(**params))
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
required: Optional[List[str]] = None
|
||||||
|
|
||||||
|
@validator("parameters")
|
||||||
|
def check_parameters(cls, params):
|
||||||
|
if not isinstance(params, dict):
|
||||||
|
raise ValueError("Parameters must be a dictionary.")
|
||||||
|
return params
|
||||||
|
|
||||||
|
def validate_params(self, params: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validates the parameters against the function's specification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (Dict[str, Any]): The parameters to validate.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any required parameter is missing or if any parameter is invalid.
|
||||||
|
"""
|
||||||
|
for key, value in params.items():
|
||||||
|
if key in self.parameters:
|
||||||
|
self.parameters[key]
|
||||||
|
# Perform specific validation based on param_spec
|
||||||
|
# This can include type checking, range validation, etc.
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected parameter: {key}")
|
||||||
|
|
||||||
|
for req_param in self.required or []:
|
||||||
|
if req_param not in params:
|
||||||
|
raise ValueError(f"Missing required parameter: {req_param}")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFunctionCaller:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
openai_api_key: str,
|
||||||
|
model: str = "text-davinci-003",
|
||||||
|
max_tokens: int = 3000,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
n: int = 1,
|
||||||
|
stream: bool = False,
|
||||||
|
stop: Optional[str] = None,
|
||||||
|
echo: bool = False,
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
logprobs: Optional[int] = None,
|
||||||
|
best_of: int = 1,
|
||||||
|
logit_bias: Dict[str, float] = None,
|
||||||
|
user: str = None,
|
||||||
|
messages: List[Dict] = None,
|
||||||
|
timeout_sec: Union[float, None] = None,
|
||||||
|
):
|
||||||
|
self.openai_api_key = openai_api_key
|
||||||
|
self.model = model
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.n = n
|
||||||
|
self.stream = stream
|
||||||
|
self.stop = stop
|
||||||
|
self.echo = echo
|
||||||
|
self.frequency_penalty = frequency_penalty
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
self.logprobs = logprobs
|
||||||
|
self.best_of = best_of
|
||||||
|
self.logit_bias = logit_bias
|
||||||
|
self.user = user
|
||||||
|
self.messages = messages if messages is not None else []
|
||||||
|
self.timeout_sec = timeout_sec
|
||||||
|
|
||||||
|
def add_message(self, role: str, content: str):
|
||||||
|
self.messages.append({"role": role, "content": content})
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3)
|
||||||
|
)
|
||||||
|
def chat_completion_request(
|
||||||
|
self,
|
||||||
|
messages,
|
||||||
|
tools=None,
|
||||||
|
tool_choice=None,
|
||||||
|
):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": "Bearer " + openai.api_key,
|
||||||
|
}
|
||||||
|
json_data = {"model": self.model, "messages": messages}
|
||||||
|
if tools is not None:
|
||||||
|
json_data.update({"tools": tools})
|
||||||
|
if tool_choice is not None:
|
||||||
|
json_data.update({"tool_choice": tool_choice})
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=json_data,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
print("Unable to generate ChatCompletion response")
|
||||||
|
print(f"Exception: {e}")
|
||||||
|
return e
|
||||||
|
|
||||||
|
def pretty_print_conversation(self, messages):
|
||||||
|
role_to_color = {
|
||||||
|
"system": "red",
|
||||||
|
"user": "green",
|
||||||
|
"assistant": "blue",
|
||||||
|
"tool": "magenta",
|
||||||
|
}
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
f"system: {message['content']}\n",
|
||||||
|
role_to_color[message["role"]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif message["role"] == "user":
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
f"user: {message['content']}\n", role_to_color[message["role"]]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif message["role"] == "assistant" and message.get("function_call"):
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
f"assistant: {message['function_call']}\n",
|
||||||
|
role_to_color[message["role"]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif message["role"] == "assistant" and not message.get("function_call"):
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
f"assistant: {message['content']}\n",
|
||||||
|
role_to_color[message["role"]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
f"function ({message['name']}): {message['content']}\n",
|
||||||
|
role_to_color[message["role"]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, prompt: str) -> Dict:
|
||||||
|
response = openai.Completion.create(
|
||||||
|
engine=self.model,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
n=self.n,
|
||||||
|
stream=self.stream,
|
||||||
|
stop=self.stop,
|
||||||
|
echo=self.echo,
|
||||||
|
frequency_penalty=self.frequency_penalty,
|
||||||
|
presence_penalty=self.presence_penalty,
|
||||||
|
logprobs=self.logprobs,
|
||||||
|
best_of=self.best_of,
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
|
user=self.user,
|
||||||
|
messages=self.messages,
|
||||||
|
timeout_sec=self.timeout_sec,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def run(self, prompt: str) -> str:
|
||||||
|
response = self.call(prompt)
|
||||||
|
return response["choices"][0]["text"].strip()
|
@ -0,0 +1,253 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import backoff
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionXLPipeline
|
||||||
|
from PIL import Image
|
||||||
|
from pydantic import validator
|
||||||
|
from termcolor import colored
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SSD1B:
|
||||||
|
"""
|
||||||
|
SSD1B model class
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
image_url: str
|
||||||
|
The image url generated by the SSD1B API
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
--------
|
||||||
|
__call__(self, task: str) -> SSD1B:
|
||||||
|
Makes a call to the SSD1B API and returns the image url
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
model = SSD1B()
|
||||||
|
task = "A painting of a dog"
|
||||||
|
neg_prompt = "ugly, blurry, poor quality"
|
||||||
|
image_url = model(task, neg_prompt)
|
||||||
|
print(image_url)
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str = "dall-e-3"
|
||||||
|
img: str = None
|
||||||
|
size: str = "1024x1024"
|
||||||
|
max_retries: int = 3
|
||||||
|
quality: str = "standard"
|
||||||
|
model_name: str = "segment/SSD-1B"
|
||||||
|
n: int = 1
|
||||||
|
save_path: str = "images"
|
||||||
|
max_time_seconds: int = 60
|
||||||
|
save_folder: str = "images"
|
||||||
|
image_format: str = "png"
|
||||||
|
device: str = "cuda"
|
||||||
|
dashboard: bool = False
|
||||||
|
cache = TTLCache(maxsize=100, ttl=3600)
|
||||||
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
|
"segmind/SSD-1B",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
use_safetensors=True,
|
||||||
|
variant="fp16",
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Post init method"""
|
||||||
|
|
||||||
|
if self.img is not None:
|
||||||
|
self.img = self.convert_to_bytesio(self.img)
|
||||||
|
|
||||||
|
os.makedirs(self.save_path, exist_ok=True)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Config class for the SSD1B model"""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@validator("max_retries", "time_seconds")
|
||||||
|
def must_be_positive(cls, value):
|
||||||
|
if value <= 0:
|
||||||
|
raise ValueError("Must be positive")
|
||||||
|
return value
|
||||||
|
|
||||||
|
def read_img(self, img: str):
|
||||||
|
"""Read the image using pil"""
|
||||||
|
img = Image.open(img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def set_width_height(self, img: str, width: int, height: int):
|
||||||
|
"""Set the width and height of the image"""
|
||||||
|
img = self.read_img(img)
|
||||||
|
img = img.resize((width, height))
|
||||||
|
return img
|
||||||
|
|
||||||
|
def convert_to_bytesio(self, img: str, format: str = "PNG"):
|
||||||
|
"""Convert the image to an bytes io object"""
|
||||||
|
byte_stream = BytesIO()
|
||||||
|
img.save(byte_stream, format=format)
|
||||||
|
byte_array = byte_stream.getvalue()
|
||||||
|
return byte_array
|
||||||
|
|
||||||
|
@backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
|
||||||
|
def __call__(self, task: str, neg_prompt: str):
|
||||||
|
"""
|
||||||
|
Text to image conversion using the SSD1B API
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
task: str
|
||||||
|
The task to be converted to an image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
SSD1B:
|
||||||
|
An instance of the SSD1B class with the image url generated by the SSD1B API
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
>>> dalle3 = SSD1B()
|
||||||
|
>>> task = "A painting of a dog"
|
||||||
|
>>> image_url = dalle3(task)
|
||||||
|
>>> print(image_url)
|
||||||
|
https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
|
||||||
|
"""
|
||||||
|
if self.dashboard:
|
||||||
|
self.print_dashboard()
|
||||||
|
if task in self.cache:
|
||||||
|
return self.cache[task]
|
||||||
|
try:
|
||||||
|
img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0]
|
||||||
|
|
||||||
|
# Generate a unique filename for the image
|
||||||
|
img_name = f"{uuid.uuid4()}.{self.image_format}"
|
||||||
|
img_path = os.path.join(self.save_path, img_name)
|
||||||
|
|
||||||
|
# Save the image
|
||||||
|
img.save(img_path, self.image_format)
|
||||||
|
self.cache[task] = img_path
|
||||||
|
|
||||||
|
return img_path
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
# Handling exceptions and printing the errors details
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
(
|
||||||
|
f"Error running SSD1B: {error} try optimizing your api key and"
|
||||||
|
" or try again"
|
||||||
|
),
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def _generate_image_name(self, task: str):
|
||||||
|
"""Generate a sanitized file name based on the task"""
|
||||||
|
sanitized_task = "".join(
|
||||||
|
char for char in task if char.isalnum() or char in " _ -"
|
||||||
|
).rstrip()
|
||||||
|
return f"{sanitized_task}.{self.image_format}"
|
||||||
|
|
||||||
|
def _download_image(self, img: Image, filename: str):
|
||||||
|
"""
|
||||||
|
Save the PIL Image object to a file.
|
||||||
|
"""
|
||||||
|
full_path = os.path.join(self.save_path, filename)
|
||||||
|
img.save(full_path, self.image_format)
|
||||||
|
|
||||||
|
def print_dashboard(self):
|
||||||
|
"""Print the SSD1B dashboard"""
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
(
|
||||||
|
f"""SSD1B Dashboard:
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
Model: {self.model}
|
||||||
|
Image: {self.img}
|
||||||
|
Size: {self.size}
|
||||||
|
Max Retries: {self.max_retries}
|
||||||
|
Quality: {self.quality}
|
||||||
|
N: {self.n}
|
||||||
|
Save Path: {self.save_path}
|
||||||
|
Time Seconds: {self.time_seconds}
|
||||||
|
Save Folder: {self.save_folder}
|
||||||
|
Image Format: {self.image_format}
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
"green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Process a batch of tasks concurrently
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks (List[str]): A list of tasks to be processed
|
||||||
|
max_workers (int): The maximum number of workers to use for the concurrent processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
results (List[str]): A list of image urls generated by the SSD1B API
|
||||||
|
|
||||||
|
Example:
|
||||||
|
--------
|
||||||
|
>>> model = SSD1B()
|
||||||
|
>>> tasks = ["A painting of a dog", "A painting of a cat"]
|
||||||
|
>>> results = model.process_batch_concurrently(tasks)
|
||||||
|
>>> print(results)
|
||||||
|
|
||||||
|
"""
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
future_to_task = {executor.submit(self, task): task for task in tasks}
|
||||||
|
results = []
|
||||||
|
for future in concurrent.futures.as_completed(future_to_task):
|
||||||
|
task = future_to_task[future]
|
||||||
|
try:
|
||||||
|
img = future.result()
|
||||||
|
results.append(img)
|
||||||
|
|
||||||
|
print(f"Task {task} completed: {img}")
|
||||||
|
except Exception as error:
|
||||||
|
print(
|
||||||
|
colored(
|
||||||
|
(
|
||||||
|
f"Error running SSD1B: {error} try optimizing your api key and"
|
||||||
|
" or try again"
|
||||||
|
),
|
||||||
|
"red",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(colored(f"Error running SSD1B: {error.http_status}", "red"))
|
||||||
|
print(colored(f"Error running SSD1B: {error.error}", "red"))
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def _generate_uuid(self):
|
||||||
|
"""Generate a uuid"""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""Repr method for the SSD1B class"""
|
||||||
|
return f"SSD1B(image_url={self.image_url})"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Str method for the SSD1B class"""
|
||||||
|
return f"SSD1B(image_url={self.image_url})"
|
||||||
|
|
||||||
|
@backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
|
||||||
|
def rate_limited_call(self, task: str):
|
||||||
|
"""Rate limited call to the SSD1B API"""
|
||||||
|
return self.__call__(task)
|
@ -0,0 +1,97 @@
|
|||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class Yi34B200k:
|
||||||
|
"""
|
||||||
|
A class for eaasy interaction with Yi34B200k
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
model_id: str
|
||||||
|
The model id of the model to be used.
|
||||||
|
device_map: str
|
||||||
|
The device to be used for inference.
|
||||||
|
torch_dtype: str
|
||||||
|
The torch dtype to be used for inference.
|
||||||
|
max_length: int
|
||||||
|
The maximum length of the generated text.
|
||||||
|
repitition_penalty: float
|
||||||
|
The repitition penalty to be used for inference.
|
||||||
|
no_repeat_ngram_size: int
|
||||||
|
The no repeat ngram size to be used for inference.
|
||||||
|
temperature: float
|
||||||
|
The temperature to be used for inference.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
--------
|
||||||
|
__call__(self, task: str) -> str:
|
||||||
|
Generates text based on the given prompt.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str = "01-ai/Yi-34B-200K",
|
||||||
|
device_map: str = "auto",
|
||||||
|
torch_dtype: str = "auto",
|
||||||
|
max_length: int = 512,
|
||||||
|
repitition_penalty: float = 1.3,
|
||||||
|
no_repeat_ngram_size: int = 5,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
top_k: int = 40,
|
||||||
|
top_p: float = 0.8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model_id = model_id
|
||||||
|
self.device_map = device_map
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.max_length = max_length
|
||||||
|
self.repitition_penalty = repitition_penalty
|
||||||
|
self.no_repeat_ngram_size = no_repeat_ngram_size
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_k = top_k
|
||||||
|
self.top_p = top_p
|
||||||
|
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
device_map=device_map,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, task: str):
|
||||||
|
"""
|
||||||
|
Generates text based on the given prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The input text prompt.
|
||||||
|
max_length (int): The maximum length of the generated text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated text.
|
||||||
|
"""
|
||||||
|
inputs = self.tokenizer(task, return_tensors="pt")
|
||||||
|
outputs = self.model.generate(
|
||||||
|
inputs.input_ids.cuda(),
|
||||||
|
max_length=self.max_length,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
|
do_sample=True,
|
||||||
|
repetition_penalty=self.repitition_penalty,
|
||||||
|
no_repeat_ngram_size=self.no_repeat_ngram_size,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_k=self.top_k,
|
||||||
|
top_p=self.top_p,
|
||||||
|
)
|
||||||
|
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
|
# # Example usage
|
||||||
|
# yi34b = Yi34B200k()
|
||||||
|
# prompt = "There's a place where time stands still. A place of breathtaking wonder, but also"
|
||||||
|
# generated_text = yi34b(prompt)
|
||||||
|
# print(generated_text)
|
Loading…
Reference in new issue