parent
2f1bf1fb46
commit
641b466c67
@ -0,0 +1,17 @@
|
|||||||
|
[package]
|
||||||
|
name = "engine"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2018"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "engine"
|
||||||
|
path = "src/my_lib.rs"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
pyo3 = { version = "0.15", features = ["extension-module"] }
|
||||||
|
rayon = "1.5.1"
|
||||||
|
log = "0.4.14"
|
||||||
|
rustcuda = "0.1.0"
|
||||||
|
rustcuda_derive = "*"
|
||||||
|
rustcuda_core = "0.1"
|
@ -1,24 +0,0 @@
|
|||||||
# Text embeddings, image embeddings, and multimodal embeddings
|
|
||||||
# Add text and image embeddings into postgresl database
|
|
||||||
|
|
||||||
from swarms.models.jina_embeds import JinaEmbeddings
|
|
||||||
from swarms.models.gigabind import Gigabind
|
|
||||||
|
|
||||||
# Model
|
|
||||||
model = JinaEmbeddings(
|
|
||||||
max_length=8192,
|
|
||||||
device="cuda",
|
|
||||||
quantize=True,
|
|
||||||
huggingface_api_key="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",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Encode text
|
|
||||||
|
|
||||||
embeddings = model("Encode this super long document text")
|
|
||||||
|
|
||||||
|
|
||||||
# Embed images or text
|
|
||||||
model = Gigabind()
|
|
||||||
|
|
||||||
multi_modal_embeddings = model(text=[text], imgs=[img1, img2, img3])
|
|
@ -0,0 +1,21 @@
|
|||||||
|
from swarms import Agent, OpenAIChat, MajorityVoting
|
||||||
|
|
||||||
|
# Initialize the llm
|
||||||
|
llm = OpenAIChat()
|
||||||
|
|
||||||
|
# Initialize the agents
|
||||||
|
agent1 = Agent(llm=llm, max_loops=1)
|
||||||
|
agent2 = Agent(llm=llm, max_loops=1)
|
||||||
|
agent3 = Agent(llm=llm, max_loops=1)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=True,
|
||||||
|
multithreaded=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Start the majority voting
|
||||||
|
mv.run("What is the capital of France?")
|
@ -0,0 +1,93 @@
|
|||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::wrap_pyfunction;
|
||||||
|
use pyo3::types::IntoPyDict;
|
||||||
|
use rayon::{ThreadPool, ThreadPoolBuilder, prelude::*};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn rust_module(py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_function(wrap_pyfunction!(concurrent_exec, m)?)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
pub fn concurrent_exec<F, G, H>(
|
||||||
|
py_codes: Vec<&str>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
|
num_threads: usize,
|
||||||
|
error_handler: F,
|
||||||
|
log_function: G,
|
||||||
|
result_handler: H,
|
||||||
|
) -> PyResult<Vec<PyResult<()>>>
|
||||||
|
|
||||||
|
/// This function wraps Python code in Rust concurrency for ultra high performance.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `py_codes` - A vector of string slices that holds the Python codes to be executed.
|
||||||
|
/// * `timeout` - An optional duration to specify a timeout for the Python code execution.
|
||||||
|
/// * `num_threads` - The number of threads to use for executing the Python code.
|
||||||
|
/// * `error_handler` - A function to handle errors during Python code execution.
|
||||||
|
/// * `log_function` - A function to log the execution of the Python code.
|
||||||
|
/// * `result_handler` - A function to handle the results of the Python code execution.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// let py_codes = vec!["print('Hello, World!')", "print('Hello, Rust!')"];
|
||||||
|
/// let timeout = Some(Duration::from_secs(5));
|
||||||
|
/// let num_threads = 4;
|
||||||
|
/// let error_handler = |e| eprintln!("Error: {}", e);
|
||||||
|
/// let log_function = |s| println!("Log: {}", s);
|
||||||
|
/// let result_handler = |r| println!("Result: {:?}", r);
|
||||||
|
/// execute_python_codes(py_codes, timeout, num_threads, error_handler, log_function, result_handler);
|
||||||
|
/// ```
|
||||||
|
where
|
||||||
|
F: Fn(&str),
|
||||||
|
G: Fn(&str),
|
||||||
|
H: Fn(&PyResult<()>),
|
||||||
|
{
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
let py_codes = Arc::new(Mutex::new(py_codes));
|
||||||
|
let results = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
let pool = ThreadPool::new(num_threads);
|
||||||
|
|
||||||
|
pool.install(|| {
|
||||||
|
py_codes.par_iter().for_each(|code| {
|
||||||
|
let locals = [("__name__", "__main__")].into_py_dict(py);
|
||||||
|
let globals = [("__name__", "__main__")].into_py_dict(py);
|
||||||
|
|
||||||
|
log_function(&format!("Executing Python code: {}", code));
|
||||||
|
let result = py.run(code, Some(globals), Some(locals));
|
||||||
|
|
||||||
|
match timeout {
|
||||||
|
Some(t) => {
|
||||||
|
let now = Instant::now();
|
||||||
|
let timeout_thread = thread::spawn(move || {
|
||||||
|
while now.elapsed() < t {
|
||||||
|
if let Ok(_) = result {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if now.elapsed() >= t {
|
||||||
|
error_handler(&format!("Python code execution timed out: {}", code));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
timeout_thread.join().unwrap();
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.lock().unwrap().push(result.clone());
|
||||||
|
result_handler(&result);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
pool.join();
|
||||||
|
Ok(results.lock().unwrap().clone())
|
||||||
|
}
|
@ -0,0 +1,71 @@
|
|||||||
|
use pyo3::prelude::*;
|
||||||
|
use rustacuda::prelude::*;
|
||||||
|
use rustacuda::memory::DeviceBox;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::ffi::CString;
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn rust_cuda(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
#[pyfn(m, "execute_on_device")]
|
||||||
|
fn execute_on_device(py: Python, device_id: u32, a: f32, b: f32) -> PyResult<f32> {
|
||||||
|
/// The result of executing the CUDA operation.
|
||||||
|
let result = py.allow_threads(|| {
|
||||||
|
execute_cuda(device_id, a, b)
|
||||||
|
});
|
||||||
|
match result {
|
||||||
|
Ok(res) => Ok(res),
|
||||||
|
Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!("{}", err))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn execute_cuda(device_id: u32, a: f32, b: f32) -> Result<f32, Box<dyn Error>> {
|
||||||
|
rustacuda::init(CudaFlags::empty())?;
|
||||||
|
let device = Device::get_device(device_id)?;
|
||||||
|
/// Creates a new CUDA context and pushes it onto the current thread's stack.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `flags` - The flags to be used when creating the context.
|
||||||
|
/// * `device` - The device on which the context will be created.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// The newly created CUDA context.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the context creation fails.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use swarms::cuda_wrapper::Context;
|
||||||
|
///
|
||||||
|
/// let device = 0;
|
||||||
|
/// let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device)?;
|
||||||
|
/// ```
|
||||||
|
pub fn create_and_push(flags: ContextFlags, device: i32) -> Result<Context, CudaError> {
|
||||||
|
// implementation goes here
|
||||||
|
}
|
||||||
|
let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device)?;
|
||||||
|
let module_data = CString::new(include_str!("../resources/add.ptx"))?;
|
||||||
|
let module = Module::load_from_string(&module_data)?;
|
||||||
|
let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
|
||||||
|
let mut x = DeviceBox::new(&a)?;
|
||||||
|
let mut y = DeviceBox::new(&b)?;
|
||||||
|
let mut result = DeviceBox::new(&0.0f32)?;
|
||||||
|
unsafe {
|
||||||
|
launch!(module.sum<<<1, 1, 0, stream>>>(
|
||||||
|
x.as_device_ptr(),
|
||||||
|
y.as_device_ptr(),
|
||||||
|
result.as_device_ptr(),
|
||||||
|
1
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
stream.synchronize()?;
|
||||||
|
let mut result_host = 0.0f32;
|
||||||
|
result.copy_to(&mut result_host)?;
|
||||||
|
Ok(result_host)
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
use std::fs::File;
|
||||||
|
use std::io::prelude::*;
|
||||||
|
use std::time::Instant;
|
||||||
|
use std::io::{BufReader, io};
|
||||||
|
use ranyon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||||
|
|
||||||
|
fn read_file(path: &str) -> Vec<String> {
|
||||||
|
/// Reads the contents of a file located at the specified path.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `path` - The path to the file.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A `Result` containing a vector of strings representing the lines of the file if the file was successfully read,
|
||||||
|
/// or an `io::Error` if there was an error reading the file.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use std::io;
|
||||||
|
/// use std::fs::File;
|
||||||
|
/// use std::io::BufReader;
|
||||||
|
///
|
||||||
|
/// fn read_file(path: &str) -> io::Result<Vec<String>> {
|
||||||
|
/// let contents: io::Result<Vec<String>> = BufReader::new(File::open(path).expect("Could not open file"))
|
||||||
|
/// .lines()
|
||||||
|
/// .collect();
|
||||||
|
/// contents
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
let contents: io::Result<Vec<String>> = BufReader::new(File::open(path).expect("Could not open file"))
|
||||||
|
.lines()
|
||||||
|
.collect();
|
||||||
|
return contents.expect("Could not read file");
|
||||||
|
}
|
@ -0,0 +1,110 @@
|
|||||||
|
/// This module provides a multi-threading processor for executing Python modules and functions in parallel.
|
||||||
|
/// It utilizes the `rayon` crate for parallel processing and the `pyo3` crate for interacting with the Python interpreter.
|
||||||
|
/// The `multithreading_processor` function takes a vector of `PythonModule` structs and the number of threads to use.
|
||||||
|
/// Each `PythonModule` struct contains the name of the Python module, the name of the function to call, and any arguments to pass to the function.
|
||||||
|
/// The function imports the Python module, calls the specified function, and sends any errors encountered back to the main thread.
|
||||||
|
/// If an import error occurs, a `PythonError::ImportError` is returned.
|
||||||
|
/// If a function call error occurs, a `PythonError::FunctionError` is returned.
|
||||||
|
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::wrap_pyfunction;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
use std::sync::mpsc::{channel, Sender};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use log::{info, error};
|
||||||
|
|
||||||
|
struct PythonModule<'a> {
|
||||||
|
name: &'a str,
|
||||||
|
function: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum PythonError {
|
||||||
|
ImportError(String),
|
||||||
|
FunctionError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn my_module(py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_function(wrap_pyfunction!(process_python_modules, m)?)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn process_python_modules(modules: Vec<PythonModule>, num_threads: usize) -> Result<(), PythonError> {
|
||||||
|
/// The function returns `Ok(())` if all modules are processed successfully.
|
||||||
|
/// Note: This code assumes that the necessary dependencies (`pyo3`, `rayon`, `log`) are already imported and initialized.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `modules` - A vector of `PythonModule` structs representing the Python modules and functions to execute.
|
||||||
|
/// * `num_threads` - The number of threads to use for parallel processing.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use pyo3::types::PyModule;
|
||||||
|
/// use pyo3::types::PyResult;
|
||||||
|
/// use pyo3::prelude::*;
|
||||||
|
///
|
||||||
|
/// struct PythonModule<'a> {
|
||||||
|
/// name: &'a str,
|
||||||
|
/// function: &'a str,
|
||||||
|
/// args: Vec<&'a str>,
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// #[pymodule]
|
||||||
|
/// fn multithreading_processor(modules: Vec<PythonModule>, num_threads: usize) -> Result<(), PythonError> {
|
||||||
|
/// // Function implementation
|
||||||
|
/// Ok(())
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns a `PythonError` if an import error or a function call error occurs.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// This function does not panic.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// This function is safe to call, but it assumes that the necessary dependencies (`pyo3`, `rayon`, `log`) are already imported and initialized.
|
||||||
|
// Initialize Python interpreter
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
|
||||||
|
// Set the global thread pool's configuration
|
||||||
|
rayon::ThreadPoolBuilder::new()
|
||||||
|
.num_threads(num_threads)
|
||||||
|
.build_global()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Create a channel to send errors from threads to the main thread
|
||||||
|
let (tx, rx) = channel();
|
||||||
|
let tx = Arc::new(Mutex::new(tx));
|
||||||
|
|
||||||
|
// Process each Python module in parallel
|
||||||
|
modules.par_iter().for_each(|module| {
|
||||||
|
let result = PyModule::import(py, module.name)
|
||||||
|
.map_err(|_| PythonError::ImportError(module.name.to_string()))
|
||||||
|
.and_then(|m| m.call0(module.function)
|
||||||
|
.map_err(|_| PythonError::FunctionError(module.function.to_string())));
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
let tx = tx.lock().unwrap();
|
||||||
|
tx.send(e).unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
drop(tx); // Close the sender
|
||||||
|
for error in rx {
|
||||||
|
match error {
|
||||||
|
PythonError::ImportError(module) => error!("Failed to import module {}", module),
|
||||||
|
PythonError::FunctionError(function) => error!("Failed to call function {}", function),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -0,0 +1,197 @@
|
|||||||
|
import json
|
||||||
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.conversation import Conversation
|
||||||
|
from swarms.utils.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiAgentStructure:
|
||||||
|
"""
|
||||||
|
Base class for a multi-agent structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agents (List[Agent], optional): List of agents in the structure. Defaults to None.
|
||||||
|
callbacks (Optional[Sequence[callable]], optional): List of callbacks for the structure. Defaults to None.
|
||||||
|
autosave (bool, optional): Flag indicating whether to enable autosave. Defaults to False.
|
||||||
|
logging (bool, optional): Flag indicating whether to enable logging. Defaults to False.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
agents (List[Agent]): List of agents in the structure.
|
||||||
|
callbacks (Optional[Sequence[callable]]): List of callbacks for the structure.
|
||||||
|
autosave (bool): Flag indicating whether autosave is enabled.
|
||||||
|
logging (bool): Flag indicating whether logging is enabled.
|
||||||
|
conversation (Conversation): Conversation object for the structure.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
metadata(): Get the metadata of the multi-agent structure.
|
||||||
|
save_to_json(filename: str): Save the current state of the multi-agent structure to a JSON file.
|
||||||
|
load_from_json(filename: str): Load the state of the multi-agent structure from a JSON file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agents: List[Agent] = None,
|
||||||
|
callbacks: Optional[Sequence[callable]] = None,
|
||||||
|
autosave: bool = False,
|
||||||
|
logging: bool = False,
|
||||||
|
return_metadata: bool = False,
|
||||||
|
metadata_filename: str = "multiagent_structure_metadata.json",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.agents = agents
|
||||||
|
self.callbacks = callbacks
|
||||||
|
self.autosave = autosave
|
||||||
|
self.logging = logging
|
||||||
|
self.return_metadata = return_metadata
|
||||||
|
self.metadata_filename = metadata_filename
|
||||||
|
self.conversation = Conversation(
|
||||||
|
time_enabled=True, *args, **kwargs
|
||||||
|
)
|
||||||
|
if self.logging:
|
||||||
|
self.logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle the case where the agents are not provided
|
||||||
|
# Handle agents
|
||||||
|
for agent in self.agents:
|
||||||
|
if not isinstance(agent, Agent):
|
||||||
|
raise TypeError("Agents must be of type Agent.")
|
||||||
|
|
||||||
|
if self.agents is None:
|
||||||
|
self.agents = []
|
||||||
|
|
||||||
|
# Handle the case where the callbacks are not provided
|
||||||
|
if self.callbacks is None:
|
||||||
|
self.callbacks = []
|
||||||
|
|
||||||
|
# Handle the case where the autosave is not provided
|
||||||
|
if self.autosave is None:
|
||||||
|
self.autosave = False
|
||||||
|
|
||||||
|
# Handle the case where the logging is not provided
|
||||||
|
if self.logging is None:
|
||||||
|
self.logging = False
|
||||||
|
|
||||||
|
# Handle callbacks
|
||||||
|
if callbacks is not None:
|
||||||
|
for callback in self.callbacks:
|
||||||
|
if not callable(callback):
|
||||||
|
raise TypeError("Callback must be callable.")
|
||||||
|
|
||||||
|
# Handle autosave
|
||||||
|
if autosave:
|
||||||
|
self.save_to_json(metadata_filename)
|
||||||
|
|
||||||
|
def metadata(self):
|
||||||
|
"""
|
||||||
|
Get the metadata of the multi-agent structure.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The metadata of the multi-agent structure.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"agents": self.agents,
|
||||||
|
"callbacks": self.callbacks,
|
||||||
|
"autosave": self.autosave,
|
||||||
|
"logging": self.logging,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
}
|
||||||
|
|
||||||
|
def save_to_json(self, filename: str):
|
||||||
|
"""
|
||||||
|
Save the current state of the multi-agent structure to a JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The name of the file to save the multi-agent structure to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(self.__dict__, f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def load_from_json(self, filename: str):
|
||||||
|
"""
|
||||||
|
Load the state of the multi-agent structure from a JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The name of the file to load the multi-agent structure from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
self.__dict__ = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def save_to_yaml(self, filename: str):
|
||||||
|
"""
|
||||||
|
Save the current state of the multi-agent structure to a YAML file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The name of the file to save the multi-agent structure to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
yaml.dump(self.__dict__, f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def load_from_yaml(self, filename: str):
|
||||||
|
"""
|
||||||
|
Load the state of the multi-agent structure from a YAML file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The name of the file to load the multi-agent structure from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
self.__dict__ = yaml.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}({self.__dict__})"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.__class__.__name__}({self.__dict__})"
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.agents)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.agents[index]
|
||||||
|
|
||||||
|
def __setitem__(self, index, value):
|
||||||
|
self.agents[index] = value
|
||||||
|
|
||||||
|
def __delitem__(self, index):
|
||||||
|
del self.agents[index]
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.agents)
|
||||||
|
|
||||||
|
def __reversed__(self):
|
||||||
|
return reversed(self.agents)
|
||||||
|
|
||||||
|
def __contains__(self, value):
|
||||||
|
return value in self.agents
|
@ -0,0 +1,366 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
NAME_LIST = [
|
||||||
|
"Affirmative side",
|
||||||
|
"Negative side",
|
||||||
|
"Moderator",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DebatePlayer(Agent):
|
||||||
|
def __init__(self, llm, name: str, *args, **kwargs) -> None:
|
||||||
|
"""Create a player in the debate
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name(str): model name
|
||||||
|
name (str): name of this player
|
||||||
|
temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
|
||||||
|
openai_api_key (str): As the parameter name suggests
|
||||||
|
sleep_time (float): sleep because of rate limits
|
||||||
|
"""
|
||||||
|
super(DebatePlayer, self).__init__(
|
||||||
|
llm=llm, agent_name=name, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Debate:
|
||||||
|
"""Create a debate
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): openai model name
|
||||||
|
temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
|
||||||
|
num_players (int): num of players
|
||||||
|
save_file_dir (str): dir path to json file
|
||||||
|
openai_api_key (str): As the parameter name suggests
|
||||||
|
prompts_path (str): prompts path (json file)
|
||||||
|
max_round (int): maximum Rounds of Debate
|
||||||
|
sleep_time (float): sleep because of rate limits
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
debate_agents: List[DebatePlayer],
|
||||||
|
temperature: float = 0,
|
||||||
|
num_players: int = 3,
|
||||||
|
save_file_dir: str = None,
|
||||||
|
prompts_path: str = None,
|
||||||
|
max_round: int = 3,
|
||||||
|
sleep_time: float = 0,
|
||||||
|
) -> None:
|
||||||
|
self.debate_agents = debate_agents
|
||||||
|
self.num_players = num_players
|
||||||
|
self.save_file_dir = save_file_dir
|
||||||
|
self.max_round = max_round
|
||||||
|
self.sleep_time = sleep_time
|
||||||
|
|
||||||
|
# init save file
|
||||||
|
now = datetime.now()
|
||||||
|
current_time = now.strftime("%Y-%m-%d_%H:%M:%S")
|
||||||
|
self.save_file = {
|
||||||
|
"start_time": current_time,
|
||||||
|
"end_time": "",
|
||||||
|
"temperature": temperature,
|
||||||
|
"num_players": num_players,
|
||||||
|
"success": False,
|
||||||
|
"src_lng": "",
|
||||||
|
"tgt_lng": "",
|
||||||
|
"source": "",
|
||||||
|
"reference": "",
|
||||||
|
"base_translation": "",
|
||||||
|
"debate_translation": "",
|
||||||
|
"Reason": "",
|
||||||
|
"Supported Side": "",
|
||||||
|
"players": {},
|
||||||
|
}
|
||||||
|
prompts = json.load(open(prompts_path))
|
||||||
|
self.save_file.update(prompts)
|
||||||
|
self.init_prompt()
|
||||||
|
|
||||||
|
if self.save_file["base_translation"] == "":
|
||||||
|
self.create_base()
|
||||||
|
|
||||||
|
# creat&init agents
|
||||||
|
self.create_agents()
|
||||||
|
self.init_agents()
|
||||||
|
|
||||||
|
def init_prompt(self):
|
||||||
|
def prompt_replace(key):
|
||||||
|
self.save_file[key] = (
|
||||||
|
self.save_file[key]
|
||||||
|
.replace("##src_lng##", self.save_file["src_lng"])
|
||||||
|
.replace("##tgt_lng##", self.save_file["tgt_lng"])
|
||||||
|
.replace("##source##", self.save_file["source"])
|
||||||
|
.replace(
|
||||||
|
"##base_translation##",
|
||||||
|
self.save_file["base_translation"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_replace("base_prompt")
|
||||||
|
prompt_replace("player_meta_prompt")
|
||||||
|
prompt_replace("moderator_meta_prompt")
|
||||||
|
prompt_replace("judge_prompt_last2")
|
||||||
|
|
||||||
|
def create_base(self):
|
||||||
|
print(
|
||||||
|
"\n===== Translation Task"
|
||||||
|
f" =====\n{self.save_file['base_prompt']}\n"
|
||||||
|
)
|
||||||
|
agent = DebatePlayer(
|
||||||
|
name="Baseline",
|
||||||
|
openai_api_key=self.openai_api_key,
|
||||||
|
)
|
||||||
|
agent.add_message_to_memory(self.save_file["base_prompt"])
|
||||||
|
base_translation = agent.ask()
|
||||||
|
agent.add_message_to_memory(base_translation)
|
||||||
|
self.save_file["base_translation"] = base_translation
|
||||||
|
self.save_file["affirmative_prompt"] = self.save_file[
|
||||||
|
"affirmative_prompt"
|
||||||
|
].replace("##base_translation##", base_translation)
|
||||||
|
self.save_file["players"][agent.name] = agent.memory_lst
|
||||||
|
|
||||||
|
def create_agents(self):
|
||||||
|
# creates players
|
||||||
|
self.players = [
|
||||||
|
DebatePlayer(
|
||||||
|
model_name=self.model_name,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
for name in NAME_LIST
|
||||||
|
]
|
||||||
|
self.affirmative = self.players[0]
|
||||||
|
self.negative = self.players[1]
|
||||||
|
self.moderator = self.players[2]
|
||||||
|
|
||||||
|
def init_agents(self):
|
||||||
|
# start: set meta prompt
|
||||||
|
self.affirmative.system_prompt(
|
||||||
|
self.save_file["player_meta_prompt"]
|
||||||
|
)
|
||||||
|
self.negative.system_prompt(
|
||||||
|
self.save_file["player_meta_prompt"]
|
||||||
|
)
|
||||||
|
self.moderator.system_prompt(
|
||||||
|
self.save_file["moderator_meta_prompt"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# start: first round debate, state opinions
|
||||||
|
print("===== Debate Round-1 =====\n")
|
||||||
|
self.affirmative.add_message_to_memory(
|
||||||
|
self.save_file["affirmative_prompt"]
|
||||||
|
)
|
||||||
|
self.aff_ans = self.affirmative.ask()
|
||||||
|
self.affirmative.add_message_to_memory(self.aff_ans)
|
||||||
|
|
||||||
|
self.negative.add_message_to_memory(
|
||||||
|
self.save_file["negative_prompt"].replace(
|
||||||
|
"##aff_ans##", self.aff_ans
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.neg_ans = self.negative.ask()
|
||||||
|
self.negative.add_message_to_memory(self.neg_ans)
|
||||||
|
|
||||||
|
self.moderator.add_message_to_memory(
|
||||||
|
self.save_file["moderator_prompt"]
|
||||||
|
.replace("##aff_ans##", self.aff_ans)
|
||||||
|
.replace("##neg_ans##", self.neg_ans)
|
||||||
|
.replace("##round##", "first")
|
||||||
|
)
|
||||||
|
self.mod_ans = self.moderator.ask()
|
||||||
|
self.moderator.add_message_to_memory(self.mod_ans)
|
||||||
|
self.mod_ans = eval(self.mod_ans)
|
||||||
|
|
||||||
|
def round_dct(self, num: int):
|
||||||
|
dct = {
|
||||||
|
1: "first",
|
||||||
|
2: "second",
|
||||||
|
3: "third",
|
||||||
|
4: "fourth",
|
||||||
|
5: "fifth",
|
||||||
|
6: "sixth",
|
||||||
|
7: "seventh",
|
||||||
|
8: "eighth",
|
||||||
|
9: "ninth",
|
||||||
|
10: "tenth",
|
||||||
|
}
|
||||||
|
return dct[num]
|
||||||
|
|
||||||
|
def save_file_to_json(self, id):
|
||||||
|
now = datetime.now()
|
||||||
|
current_time = now.strftime("%Y-%m-%d_%H:%M:%S")
|
||||||
|
save_file_path = os.path.join(
|
||||||
|
self.save_file_dir, f"{id}.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.save_file["end_time"] = current_time
|
||||||
|
json_str = json.dumps(
|
||||||
|
self.save_file, ensure_ascii=False, indent=4
|
||||||
|
)
|
||||||
|
with open(save_file_path, "w") as f:
|
||||||
|
f.write(json_str)
|
||||||
|
|
||||||
|
def broadcast(self, msg: str):
|
||||||
|
"""Broadcast a message to all players.
|
||||||
|
Typical use is for the host to announce public information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg (str): the message
|
||||||
|
"""
|
||||||
|
# print(msg)
|
||||||
|
for player in self.players:
|
||||||
|
player.add_message_to_memory(msg)
|
||||||
|
|
||||||
|
def speak(self, speaker: str, msg: str):
|
||||||
|
"""The speaker broadcast a message to all other players.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speaker (str): name of the speaker
|
||||||
|
msg (str): the message
|
||||||
|
"""
|
||||||
|
if not msg.startswith(f"{speaker}: "):
|
||||||
|
msg = f"{speaker}: {msg}"
|
||||||
|
# print(msg)
|
||||||
|
for player in self.players:
|
||||||
|
if player.name != speaker:
|
||||||
|
player.add_message_to_memory(msg)
|
||||||
|
|
||||||
|
def ask_and_speak(self, player: DebatePlayer):
|
||||||
|
ans = player.ask()
|
||||||
|
player.add_message_to_memory(ans)
|
||||||
|
self.speak(player.name, ans)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
for round in range(self.max_round - 1):
|
||||||
|
if self.mod_ans["debate_translation"] != "":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"===== Debate Round-{round+2} =====\n")
|
||||||
|
self.affirmative.add_message_to_memory(
|
||||||
|
self.save_file["debate_prompt"].replace(
|
||||||
|
"##oppo_ans##", self.neg_ans
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.aff_ans = self.affirmative.ask()
|
||||||
|
self.affirmative.add_message_to_memory(self.aff_ans)
|
||||||
|
|
||||||
|
self.negative.add_message_to_memory(
|
||||||
|
self.save_file["debate_prompt"].replace(
|
||||||
|
"##oppo_ans##", self.aff_ans
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.neg_ans = self.negative.ask()
|
||||||
|
self.negative.add_message_to_memory(self.neg_ans)
|
||||||
|
|
||||||
|
self.moderator.add_message_to_memory(
|
||||||
|
self.save_file["moderator_prompt"]
|
||||||
|
.replace("##aff_ans##", self.aff_ans)
|
||||||
|
.replace("##neg_ans##", self.neg_ans)
|
||||||
|
.replace("##round##", self.round_dct(round + 2))
|
||||||
|
)
|
||||||
|
self.mod_ans = self.moderator.ask()
|
||||||
|
self.moderator.add_message_to_memory(self.mod_ans)
|
||||||
|
self.mod_ans = eval(self.mod_ans)
|
||||||
|
|
||||||
|
if self.mod_ans["debate_translation"] != "":
|
||||||
|
self.save_file.update(self.mod_ans)
|
||||||
|
self.save_file["success"] = True
|
||||||
|
|
||||||
|
# ultimate deadly technique.
|
||||||
|
else:
|
||||||
|
judge_player = DebatePlayer(
|
||||||
|
model_name=self.model_name,
|
||||||
|
name="Judge",
|
||||||
|
temperature=self.temperature,
|
||||||
|
openai_api_key=self.openai_api_key,
|
||||||
|
sleep_time=self.sleep_time,
|
||||||
|
)
|
||||||
|
aff_ans = self.affirmative.memory_lst[2]["content"]
|
||||||
|
neg_ans = self.negative.memory_lst[2]["content"]
|
||||||
|
|
||||||
|
judge_player.system_prompt(
|
||||||
|
self.save_file["moderator_meta_prompt"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# extract answer candidates
|
||||||
|
judge_player.add_message_to_memory(
|
||||||
|
self.save_file["judge_prompt_last1"]
|
||||||
|
.replace("##aff_ans##", aff_ans)
|
||||||
|
.replace("##neg_ans##", neg_ans)
|
||||||
|
)
|
||||||
|
ans = judge_player.ask()
|
||||||
|
judge_player.add_message_to_memory(ans)
|
||||||
|
|
||||||
|
# select one from the candidates
|
||||||
|
judge_player.add_message_to_memory(
|
||||||
|
self.save_file["judge_prompt_last2"]
|
||||||
|
)
|
||||||
|
ans = judge_player.ask()
|
||||||
|
judge_player.add_message_to_memory(ans)
|
||||||
|
|
||||||
|
ans = eval(ans)
|
||||||
|
if ans["debate_translation"] != "":
|
||||||
|
self.save_file["success"] = True
|
||||||
|
# save file
|
||||||
|
self.save_file.update(ans)
|
||||||
|
self.players.append(judge_player)
|
||||||
|
|
||||||
|
for player in self.players:
|
||||||
|
self.save_file["players"][player.name] = player.memory_lst
|
||||||
|
|
||||||
|
|
||||||
|
# def parse_args():
|
||||||
|
# parser = argparse.ArgumentParser("", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
|
# parser.add_argument("-i", "--input-file", type=str, required=True, help="Input file path")
|
||||||
|
# parser.add_argument("-o", "--output-dir", type=str, required=True, help="Output file dir")
|
||||||
|
# parser.add_argument("-lp", "--lang-pair", type=str, required=True, help="Language pair")
|
||||||
|
# parser.add_argument("-k", "--api-key", type=str, required=True, help="OpenAI api key")
|
||||||
|
# parser.add_argument("-m", "--model-name", type=str, default="gpt-3.5-turbo", help="Model name")
|
||||||
|
# parser.add_argument("-t", "--temperature", type=float, default=0, help="Sampling temperature")
|
||||||
|
|
||||||
|
# return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# args = parse_args()
|
||||||
|
# openai_api_key = args.api_key
|
||||||
|
|
||||||
|
# current_script_path = os.path.abspath(__file__)
|
||||||
|
# MAD_path = current_script_path.rsplit("/", 2)[0]
|
||||||
|
|
||||||
|
# src_lng, tgt_lng = args.lang_pair.split('-')
|
||||||
|
# src_full = Language.make(language=src_lng).display_name()
|
||||||
|
# tgt_full = Language.make(language=tgt_lng).display_name()
|
||||||
|
|
||||||
|
# config = json.load(open(f"{MAD_path}/code/utils/config4tran.json", "r"))
|
||||||
|
|
||||||
|
# inputs = open(args.input_file, "r").readlines()
|
||||||
|
# inputs = [l.strip() for l in inputs]
|
||||||
|
|
||||||
|
# save_file_dir = args.output_dir
|
||||||
|
# if not os.path.exists(save_file_dir):
|
||||||
|
# os.mkdir(save_file_dir)
|
||||||
|
|
||||||
|
# for id, input in enumerate(tqdm(inputs)):
|
||||||
|
# # files = os.listdir(save_file_dir)
|
||||||
|
# # if f"{id}.json" in files:
|
||||||
|
# # continue
|
||||||
|
|
||||||
|
# prompts_path = f"{save_file_dir}/{id}-config.json"
|
||||||
|
|
||||||
|
# config['source'] = input.split('\t')[0]
|
||||||
|
# config['reference'] = input.split('\t')[1]
|
||||||
|
# config['src_lng'] = src_full
|
||||||
|
# config['tgt_lng'] = tgt_full
|
||||||
|
|
||||||
|
# with open(prompts_path, 'w') as file:
|
||||||
|
# json.dump(config, file, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
# debate = Debate(save_file_dir=save_file_dir, num_players=3, openai_api_key=openai_api_key, prompts_path=prompts_path, temperature=0, sleep_time=0)
|
||||||
|
# debate.run()
|
||||||
|
# debate.save_file_to_json(id)
|
@ -0,0 +1,93 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.conversation import Conversation
|
||||||
|
from swarms.utils.logger import logger
|
||||||
|
from swarms.structs.base_multiagent_structure import (
|
||||||
|
BaseMultiAgentStructure,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StackOverflowSwarm(BaseMultiAgentStructure):
|
||||||
|
"""
|
||||||
|
Represents a swarm of agents that work together to solve a problem or answer a question on Stack Overflow.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
agents (List[Agent]): The list of agents in the swarm.
|
||||||
|
autosave (bool): Flag indicating whether to automatically save the conversation.
|
||||||
|
verbose (bool): Flag indicating whether to display verbose output.
|
||||||
|
save_filepath (str): The filepath to save the conversation.
|
||||||
|
conversation (Conversation): The conversation object for storing the interactions.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from swarms.structs.agent import Agent
|
||||||
|
>>> from swarms.structs.stack_overflow_swarm import StackOverflowSwarm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agents: List[Agent],
|
||||||
|
autosave: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
save_filepath: str = "stack_overflow_swarm.json",
|
||||||
|
eval_agent: Agent = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.agents = agents
|
||||||
|
self.autosave = autosave
|
||||||
|
self.verbose = verbose
|
||||||
|
self.save_filepath = save_filepath
|
||||||
|
self.eval_agent = eval_agent
|
||||||
|
|
||||||
|
# Configure conversation
|
||||||
|
self.conversation = Conversation(
|
||||||
|
time_enabled=True,
|
||||||
|
autosave=autosave,
|
||||||
|
save_filepath=save_filepath,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Counter for the number of upvotes per post
|
||||||
|
self.upvotes = 0
|
||||||
|
|
||||||
|
# Counter for the number of downvotes per post
|
||||||
|
self.downvotes = 0
|
||||||
|
|
||||||
|
# Forum for the agents to interact
|
||||||
|
self.forum = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the swarm to solve a problem or answer a question like stack overflow
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to be performed by the agents.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: The conversation history.
|
||||||
|
"""
|
||||||
|
# Add the task to the conversation
|
||||||
|
self.conversation.add("Human", task)
|
||||||
|
logger.info(f"Task: {task} Added to the Forum.")
|
||||||
|
|
||||||
|
# Run the agents and get their responses and append to the conversation
|
||||||
|
for agent in self.agents:
|
||||||
|
response = agent.run(
|
||||||
|
self.conversation.return_history_as_string(),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Add to the conversation
|
||||||
|
self.conversation.add(
|
||||||
|
agent.ai_name, f"{response}"
|
||||||
|
)
|
||||||
|
logger.info(f"[{agent.ai_name}]: [{response}]")
|
||||||
|
|
||||||
|
return self.conversation.return_history_as_string()
|
@ -0,0 +1,26 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use pyo3::types::IntoPyDict;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_execute_on_device() {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
|
||||||
|
// Define a Python module for testing
|
||||||
|
let rust_cuda = PyModule::new(py, "rust_cuda").unwrap();
|
||||||
|
rust_cuda.add_function(wrap_pyfunction!(execute_on_device, rust_cuda).unwrap()).unwrap();
|
||||||
|
|
||||||
|
// Test the execute_on_device function
|
||||||
|
let result: PyResult<f32> = rust_cuda.call1("execute_on_device", (0, 1.0f32, 2.0f32)).unwrap().extract().unwrap();
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_execute_cuda() {
|
||||||
|
// Test the execute_cuda function
|
||||||
|
let result = execute_cuda(0, 1.0f32, 2.0f32);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use pyo3::types::IntoPyDict;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_process_python_modules() {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
|
||||||
|
// Define a Python module for testing
|
||||||
|
let code = r#"
|
||||||
|
def test_function():
|
||||||
|
return "Hello, World!"
|
||||||
|
"#;
|
||||||
|
let test_module = PyModule::new(py, "test_module").unwrap();
|
||||||
|
test_module.add_function(wrap_pyfunction!(test_function, test_module).unwrap()).unwrap();
|
||||||
|
test_module.add(py, "test_function", code).unwrap();
|
||||||
|
|
||||||
|
// Define a PythonModule struct for testing
|
||||||
|
let test_python_module = PythonModule {
|
||||||
|
name: "test_module",
|
||||||
|
function: "test_function",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test the process_python_modules function
|
||||||
|
let result = process_python_modules(vec![test_python_module], 1);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_process_python_modules_import_error() {
|
||||||
|
// Define a PythonModule struct with a non-existent module
|
||||||
|
let test_python_module = PythonModule {
|
||||||
|
name: "non_existent_module",
|
||||||
|
function: "test_function",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test the process_python_modules function
|
||||||
|
let result = process_python_modules(vec![test_python_module], 1);
|
||||||
|
assert!(matches!(result, Err(PythonError::ImportError(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_process_python_modules_function_error() {
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
|
||||||
|
// Define a Python module for testing
|
||||||
|
let test_module = PyModule::new(py, "test_module").unwrap();
|
||||||
|
|
||||||
|
// Define a PythonModule struct with a non-existent function
|
||||||
|
let test_python_module = PythonModule {
|
||||||
|
name: "test_module",
|
||||||
|
function: "non_existent_function",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test the process_python_modules function
|
||||||
|
let result = process_python_modules(vec![test_python_module], 1);
|
||||||
|
assert!(matches!(result, Err(PythonError::FunctionError(_))));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue