parent 1e7514f98e
commit d7cdb6c917
@@ -1,380 +0,0 @@
"""
|
|
||||||
Evaluate model performance using vLLM and unsloth.
|
|
||||||
|
|
||||||
This script evaluates the performance of a model using vLLM for fast inference
|
|
||||||
and unsloth for LoRA support.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from unsloth import FastLanguageModel
|
|
||||||
from vllm import SamplingParams
|
|
||||||
|
|
||||||
from src import (
|
|
||||||
apply_chat_template,
|
|
||||||
build_reward_correctness_fn,
|
|
||||||
build_user_prompt,
|
|
||||||
get_qa_dataset,
|
|
||||||
get_system_prompt,
|
|
||||||
run_eval,
|
|
||||||
)
|
|
||||||
from config import MODEL_NAME, logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config():
|
|
||||||
"""Get model configuration."""
|
|
||||||
return {
|
|
||||||
"max_seq_length": 4096 * 2,
|
|
||||||
"lora_rank": 64,
|
|
||||||
"gpu_memory_utilization": 0.6,
|
|
||||||
"model_name": MODEL_NAME,
|
|
||||||
"target_modules": [
|
|
||||||
"q_proj",
|
|
||||||
"k_proj",
|
|
||||||
"v_proj",
|
|
||||||
"o_proj",
|
|
||||||
"gate_proj",
|
|
||||||
"up_proj",
|
|
||||||
"down_proj",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_sampling_params(temperature=0.5):
|
|
||||||
"""Get sampling parameters for generation."""
|
|
||||||
return SamplingParams(
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=0.95,
|
|
||||||
max_tokens=4096,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_tokenizer():
|
|
||||||
"""Initialize model and tokenizer with LoRA support."""
|
|
||||||
config = get_model_config()
|
|
||||||
logger.info(f"Setting up model {config['model_name']} with LoRA support...")
|
|
||||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
||||||
model_name=config["model_name"],
|
|
||||||
max_seq_length=config["max_seq_length"],
|
|
||||||
load_in_4bit=True,
|
|
||||||
fast_inference=True,
|
|
||||||
max_lora_rank=config["lora_rank"],
|
|
||||||
gpu_memory_utilization=config["gpu_memory_utilization"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup LoRA
|
|
||||||
model = FastLanguageModel.get_peft_model(
|
|
||||||
model,
|
|
||||||
r=config["lora_rank"],
|
|
||||||
target_modules=config["target_modules"],
|
|
||||||
lora_alpha=config["lora_rank"],
|
|
||||||
use_gradient_checkpointing=True,
|
|
||||||
random_state=3407,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Model and tokenizer setup complete.")
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(
    model,
    tokenizer,
    lora_path=None,
    temperature=0.5,
    output_file="eval_results.txt",
    trainer_dir=None,
):
    """
    Evaluate model with or without LoRA weights.

    Args:
        model: The model to evaluate
        tokenizer: The tokenizer
        lora_path: Path to LoRA weights (None for base model)
        temperature: Sampling temperature
        output_file: File to write results to
        trainer_dir: Directory containing the checkpoints
    """
    sampling_params = get_sampling_params(temperature=temperature)

    # Set up output directory
    if trainer_dir:
        eval_log_dir = os.path.join(trainer_dir, "eval_logs")
    else:
        eval_log_dir = "eval_logs"
    os.makedirs(eval_log_dir, exist_ok=True)

    # Create file names based on model type
    model_prefix = "lora" if lora_path else "base"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Define all output file paths
    eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log")
    output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt")
    debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json")

    logger.info(f"Writing evaluation log to: {eval_log_file}")
    logger.info(f"Results will be saved to: {output_file}")

    # Function to generate completions using agentic approach
    def eval_generate_fn(inputs):
        start_time = time.time()

        # Format inputs as chat messages with system prompt
        messages = [
            {
                "messages": [
                    {"role": "system", "content": get_system_prompt()},
                    {"role": "user", "content": build_user_prompt(input_text)},
                ]
            }
            for input_text in inputs
        ]

        if lora_path:
            lora_request = model.load_lora(lora_path)
            load_time = time.time() - start_time
            logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}")
            responses = model.fast_generate(
                [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
                sampling_params=sampling_params,
                lora_request=lora_request,
            )
        else:
            responses = model.fast_generate(
                [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
                sampling_params=sampling_params,
            )

        gen_time = time.time() - start_time
        logger.debug(f"Generation completed in {gen_time:.2f} seconds")
        return responses

    def verifier_generate_fn(inputs):
        # Use a lower temperature for verification to get more consistent results
        verifier_params = get_sampling_params(temperature=0.1)

        # Format inputs as chat messages with system prompt
        messages = [
            {
                "messages": [
                    {"role": "system", "content": get_system_prompt()},
                    {"role": "user", "content": build_user_prompt(input_text)},
                ]
            }
            for input_text in inputs
        ]

        return model.fast_generate(
            [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages],
            sampling_params=verifier_params,
        )

    # Prepare the verification function
    verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer)

    # Get the dataset and prepare questions and answers
    train_dataset, test_dataset = get_qa_dataset()
    questions = test_dataset["prompt"]
    inputs = questions

    logger.info(f"Verifying {len(inputs)} answers...")

    # Run the evaluation
    start_time = time.time()
    model_type = "LoRA" if lora_path else "Base"
    logger.info(f"Starting {model_type} model evaluation...")

    # Run evaluation using the agentic approach
    full_chat_states = run_eval(
        generate_fn=eval_generate_fn,
        verify_fn=verify_fn,
        tokenizer=tokenizer,
        output_file=output_file,
        debug_file=debug_file,
    )

    # Calculate rewards
    logger.info(f"Calculating rewards for {model_type} model...")
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
    avg_reward = sum(rewards) / len(rewards)
    total_time = time.time() - start_time

    # Record the results
    results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "model_type": model_type,
        "model_name": MODEL_NAME,
        "lora_path": lora_path if lora_path else "None",
        "accuracy": avg_reward,
        "correct_count": sum(rewards),
        "total_count": len(rewards),
        "temperature": temperature,
        "time_taken": total_time,
    }

    # Add more detailed output to log file
    logger.info(f"\n{'=' * 50}")
    logger.info(f"{model_type.upper()} MODEL EVALUATION RESULTS:")
    logger.info(f"{'=' * 50}")
    logger.info(f"Accuracy: {avg_reward:.4f} ({sum(rewards)}/{len(rewards)} correct)")
    logger.info(f"Temperature: {temperature}")
    logger.info(f"Time taken: {total_time:.2f} seconds")
    logger.info(f"Results file: {output_file}")
    logger.info(f"Debug file: {debug_file}")
    logger.info(f"Log file: {eval_log_file}")

    # Write a summary to the log file too
    with open(eval_log_file, "a") as f:
        f.write(f"\n{'=' * 50}\n")
        f.write(f"{model_type.upper()} MODEL EVALUATION SUMMARY\n")
        f.write(f"{'=' * 50}\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Accuracy: {avg_reward:.4f} ({sum(rewards)}/{len(rewards)} correct)\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Time taken: {total_time:.2f} seconds\n")
        f.write(f"Results saved to: {output_file}\n")
        f.write(f"Debug data saved to: {debug_file}\n\n")

    logger.info(f"Evaluation completed. Results saved to {output_file} and {debug_file}")

    return results


def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=None):
    """
    Compare base model with LoRA model.

    Args:
        lora_path: Path to LoRA weights
        temperature: Sampling temperature
        output_file: File to write results to (optional)
        trainer_dir: Directory containing the trainer output
    """
    # Set up output directory
    if trainer_dir:
        eval_log_dir = os.path.join(trainer_dir, "eval_logs")
    else:
        eval_log_dir = "eval_logs"
    os.makedirs(eval_log_dir, exist_ok=True)

    # Define the comparison file path if not provided
    if output_file is None:
        output_file = os.path.join(eval_log_dir, "model_comparison_results.txt")

    # Define file paths for individual model results
    base_output = os.path.join(eval_log_dir, "base_model_results.txt")
    lora_output = os.path.join(eval_log_dir, "lora_model_results.txt")

    model, tokenizer = setup_model_and_tokenizer()

    # Evaluate both models
    base_results = evaluate_model(
        model,
        tokenizer,
        lora_path=None,
        temperature=temperature,
        output_file=base_output,
        trainer_dir=trainer_dir,
    )

    lora_results = evaluate_model(
        model,
        tokenizer,
        lora_path=lora_path,
        temperature=temperature,
        output_file=lora_output,
        trainer_dir=trainer_dir,
    )

    # Calculate improvement
    improvement = lora_results["accuracy"] - base_results["accuracy"]

    # Write comparison results
    with open(output_file, "w") as f:
        f.write("MODEL COMPARISON RESULTS\n")
        f.write("======================\n\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Base Model: {MODEL_NAME}\n")
        f.write(f"LoRA Path: {lora_path}\n\n")
        f.write(f"Base Model Accuracy: {base_results['accuracy']:.4f}\n")
        f.write(f"LoRA Model Accuracy: {lora_results['accuracy']:.4f}\n")
        f.write(f"Improvement: {improvement:.4f}\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Base Model Time: {base_results['time_taken']:.2f}s\n")
        f.write(f"LoRA Model Time: {lora_results['time_taken']:.2f}s\n\n")
        f.write(f"Base Model Results File: {base_output}\n")
        f.write(f"LoRA Model Results File: {lora_output}\n")

    logger.info("\nModel comparison completed.")
    logger.info(f"\n{'=' * 50}")
    logger.info("MODEL COMPARISON RESULTS:")
    logger.info(f"{'=' * 50}")
    logger.info(f"Base Model Accuracy: {base_results['accuracy']:.4f}")
    logger.info(f"LoRA Model Accuracy: {lora_results['accuracy']:.4f}")
    logger.info(f"Improvement: {improvement:.4f}")
    logger.info(f"Temperature: {temperature}")
    logger.info(f"Results written to: {output_file}")
    logger.info(f"Base Model Results: {base_output}")
    logger.info(f"LoRA Model Results: {lora_output}")
    logger.info(f"{'=' * 50}")

    return {
        "base_accuracy": base_results["accuracy"],
        "lora_accuracy": lora_results["accuracy"],
        "improvement": improvement,
        "output_file": output_file,
        "base_output": base_output,
        "lora_output": lora_output,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate and compare models")
    parser.add_argument(
        "--lora_path",
        type=str,
        default="trainer_output_example/checkpoint-101",
        help="Path to LoRA weights",
    )
    parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature")
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="File to write results to (optional)",
    )
    parser.add_argument(
        "--trainer_dir",
        type=str,
        default=None,
        help="Directory containing the trainer output",
    )
    args = parser.parse_args()

    logger.info(f"Starting model evaluation with temperature {args.temperature}")
    results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=args.trainer_dir)
    if results:
        logger.info("Evaluation completed successfully")
        logger.info(f"Final improvement: {results['improvement']:.4f}")
        logger.info(f"Results saved to: {results['output_file']}")

        # Print all output files for clarity
        logger.info("\nSUMMARY OF OUTPUT FILES:")
        logger.info(f"Comparison results: {results['output_file']}")
        logger.info(f"Base model results: {results['base_output']}")
        logger.info(f"LoRA model results: {results['lora_output']}")

        # Find and print all log files in the eval_logs directory
        eval_log_dir = os.path.join(args.trainer_dir, "eval_logs") if args.trainer_dir else "eval_logs"
        if os.path.exists(eval_log_dir):
            log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")]
            if log_files:
                logger.info("\nEVALUATION LOG FILES:")
                for log_file in log_files:
                    logger.info(f"- {os.path.join(eval_log_dir, log_file)}")
    else:
        logger.warning("Evaluation failed or was skipped")
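
A minimal usage sketch for the evaluation script above, driving the comparison programmatically rather than through the command-line flags. It assumes the deleted file is importable as `eval` (a hypothetical module name; the diff does not show the filename) and only exercises the `compare_models` signature defined above.

# Sketch only, not part of the deleted file. `eval` is a hypothetical module name.
from eval import compare_models

summary = compare_models(
    lora_path="trainer_output_example/checkpoint-101",  # default from the argparse block above
    temperature=0.5,
    output_file=None,  # defaults to <eval_log_dir>/model_comparison_results.txt
    trainer_dir=None,
)
print(f"Base: {summary['base_accuracy']:.4f}  LoRA: {summary['lora_accuracy']:.4f}  Delta: {summary['improvement']:.4f}")
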
@@ -1,458 +0,0 @@
"""
|
|
||||||
Simple CLI inference script with search functionality.
|
|
||||||
|
|
||||||
This script allows interaction with the merged 16-bit model
|
|
||||||
and provides search functionality for data retrieval.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from vllm import SamplingParams
|
|
||||||
|
|
||||||
from src import (
|
|
||||||
apply_chat_template,
|
|
||||||
build_user_prompt,
|
|
||||||
extract_search_query,
|
|
||||||
format_search_results,
|
|
||||||
get_system_prompt,
|
|
||||||
)
|
|
||||||
from src.search_module import load_vectorstore, search
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_tokenizer(model_path: str):
|
|
||||||
"""Initialize model and tokenizer."""
|
|
||||||
print(f"Setting up model from {model_path}...")
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
torch_dtype="float16",
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
|
||||||
|
|
||||||
print("Model and tokenizer setup complete.")
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams:
    """Get sampling parameters for generation."""
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=max_tokens,
    )


class DeepSearchCLI:
    """CLI for interacting with the model and search functionality."""

    def __init__(
        self,
        model_path: str,
        temperature: float = 0.7,
        system_prompt: str | None = None,
    ):
        """
        Initialize the CLI.

        Args:
            model_path: Path to the merged 16-bit model
            temperature: Sampling temperature
            system_prompt: Optional system prompt to guide the model's behavior
        """
        self.model, self.tokenizer = setup_model_and_tokenizer(model_path)
        self.temperature = temperature
        self.sampling_params = get_sampling_params(temperature)
        self.history = []
        self.search_history = []
        self.system_prompt = system_prompt or get_system_prompt()

    def _run_agent_generation(self, chat_state: dict) -> dict:
        """Run a single generation step for the agent."""
        # Format the chat state using the same template as training
        formatted_prompt = apply_chat_template(chat_state, tokenizer=self.tokenizer)["text"]

        start_time = time.time()
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.sampling_params.max_tokens,
            temperature=self.sampling_params.temperature,
            top_p=self.sampling_params.top_p,
            do_sample=True,
        )
        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        gen_time = time.time() - start_time
        print(f"Generation completed in {gen_time:.2f} seconds")

        # Extract assistant response
        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]

        chat_state["messages"].append({"role": "assistant", "content": assistant_response})

        return chat_state

    def generate(self, prompt: str, max_generations: int = 20) -> str:
        """
        Generate a response to the prompt using agentic mechanism.

        Args:
            prompt: The prompt to generate a response to
            max_generations: Maximum number of turns in the conversation

        Returns:
            The generated response after completing the conversation
        """
        # Initialize chat state with the same structure as training
        chat_state = {
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": build_user_prompt(prompt)},
            ],
            "finished": False,
        }

        # Agent loop
        for i in range(max_generations):
            # Generate response
            chat_state = self._run_agent_generation(chat_state)

            # Check if conversation is finished
            chat_state = self._check_finished_chat(chat_state)
            if chat_state.get("finished"):
                break

            # Process tool calls if any
            chat_state = self._run_tool_calls(chat_state)

        # Get final response
        final_response = chat_state["messages"][-1]["content"]

        # Update history
        self.history.append({"role": "user", "content": prompt})
        self.history.append({"role": "assistant", "content": final_response})

        return final_response

    def _check_finished_chat(self, chat_state: dict) -> dict:
        """Check if the chat is finished (no more search queries)."""
        if chat_state.get("finished"):
            return chat_state

        assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

        assistant_response = chat_state["messages"][-1]["content"]
        search_query = extract_search_query(assistant_response)

        if not search_query:
            chat_state["finished"] = True

        return chat_state

    def _run_tool_calls(self, chat_state: dict) -> dict:
        """Execute tool calls found in chat state."""
        if chat_state.get("finished"):
            return chat_state

        try:
            assistant_response = chat_state["messages"][-1]["content"]
            search_query = extract_search_query(assistant_response)

            if search_query:
                print(f"🔍 Search Query: {search_query}")

                results = search(search_query, return_type=str, results=2)
                # Wrap results in <information> tags
                formatted_results = f"<information>{results}</information>"

                # Print search results to terminal
                print("\n===== SEARCH RESULTS =====")
                print(results)
                print("===========================\n")

                chat_state["messages"].append({"role": "ipython", "content": formatted_results})

                # Record search in history
                search_entry = {
                    "turn": len(self.history) // 2,
                    "searches": [{"query": search_query, "results": results}],
                }
                self.search_history.append(search_entry)

        except Exception as e:
            print(f"Error during tool call: {str(e)}")
            chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
            chat_state["finished"] = True

        return chat_state

    def clear_history(self):
        """Clear the conversation history."""
        self.history = []
        self.search_history = []
        print("Conversation history cleared.")

    def set_system_prompt(self, prompt: str):
        """
        Set a new system prompt.

        Args:
            prompt: The new system prompt
        """
        if not prompt:
            print("System prompt cannot be empty. Using default.")
            return

        self.system_prompt = prompt
        print("System prompt updated.")
        print(f"New system prompt: {self.system_prompt}")

    def display_welcome(self):
        """Display welcome message."""
        print(f"\n{'=' * 50}")
        print(f"DeepSearch CLI - {self.model.name_or_path}")
        print(f"Model: {self.model.name_or_path}")
        print(f"Temperature: {self.temperature}")
        print(f"System Prompt: {self.system_prompt}")
        print(f"{'=' * 50}")
        print("Type 'help' to see available commands.")

    def print_pretty_chat_history(self):
        """Print the full chat history in a pretty format, including searches."""
        if not self.history:
            print("No chat history available.")
            return

        print("\n" + "=" * 80)
        print("CHAT HISTORY WITH SEARCH DETAILS")
        print("=" * 80)

        # Group history into conversation turns
        for i in range(0, len(self.history), 2):
            turn_number = i // 2

            # Print user message
            if i < len(self.history):
                user_msg = self.history[i]["content"]
                print(f"\n[Turn {turn_number + 1}] USER: ")
                print("-" * 40)
                print(user_msg)

            # Print searches associated with this turn if any
            for search_entry in self.search_history:
                if search_entry["turn"] == turn_number:
                    for idx, search in enumerate(search_entry["searches"]):
                        print(f'\n🔍 SEARCH {idx + 1}: "{search["query"]}"')
                        print("-" * 40)
                        print(search["results"])

            # Print assistant response
            if i + 1 < len(self.history):
                assistant_msg = self.history[i + 1]["content"]
                print(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
                print("-" * 40)
                print(assistant_msg)

        print("\n" + "=" * 80 + "\n")

    def save_chat_history(self, filepath=None):
        """
        Save chat history to a file.

        Args:
            filepath: Path to save file (if None, auto-generate based on timestamp)

        Returns:
            Path to the saved file
        """
        if not self.history:
            print("No chat history to save.")
            return None

        # Generate a default filepath if none provided
        if filepath is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.txt")

        # Ensure the directory exists
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # Prepare chat history data
        pretty_history = []

        # Group history into conversation turns
        for i in range(0, len(self.history), 2):
            turn_number = i // 2
            turn_data = {
                "turn": turn_number + 1,
                "user": self.history[i]["content"] if i < len(self.history) else "",
                "searches": [],
                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
            }

            # Add searches for this turn
            for search_entry in self.search_history:
                if search_entry["turn"] == turn_number:
                    turn_data["searches"].extend(search_entry["searches"])

            pretty_history.append(turn_data)

        # Write to file
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(f"{'=' * 80}\n")
                f.write("DEEPSEARCH CHAT HISTORY\n")
                f.write(f"Model: {self.model.name_or_path}\n")
                f.write(f"Temperature: {self.temperature}\n")
                f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"{'=' * 80}\n\n")

                for turn in pretty_history:
                    f.write(f"[Turn {turn['turn']}] USER:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['user']}\n\n")

                    # Write searches
                    for i, search in enumerate(turn["searches"]):
                        f.write(f'🔍 SEARCH {i + 1}: "{search["query"]}"\n')
                        f.write(f"{'-' * 40}\n")
                        f.write(f"{search['results']}\n\n")

                    f.write(f"[Turn {turn['turn']}] ASSISTANT:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['assistant']}\n\n")
                    f.write(f"{'=' * 40}\n\n")

            print(f"Chat history saved to: {filepath}")
            return filepath

        except Exception as e:
            print(f"Error saving chat history: {e}")
            return None

    def display_help(self):
        """Display help information."""
        print("\n===== Commands =====")
        print("search <query> - Search for information")
        print("system <prompt> - Set a new system prompt")
        print("clear - Clear conversation history")
        print("history - Display full chat history with searches")
        print("save - Save chat history to a text file")
        print("help - Display this help message")
        print("exit/quit - Exit the program")
        print("Any other input will be treated as a prompt to the model.")
        print("===================\n")

    def run(self):
        """Run the CLI."""
        self.display_welcome()

        while True:
            try:
                user_input = input("\n> ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ["exit", "quit"]:
                    print("Exiting...")
                    break

                if user_input.lower() == "help":
                    self.display_help()
                    continue

                if user_input.lower() == "clear":
                    self.clear_history()
                    continue

                if user_input.lower() == "history":
                    self.print_pretty_chat_history()
                    continue

                if user_input.lower() == "save":
                    self.save_chat_history()
                    continue

                if user_input.lower().startswith("system "):
                    new_prompt = user_input[7:].strip()
                    self.set_system_prompt(new_prompt)
                    continue

                if user_input.lower().startswith("search "):
                    query = user_input[7:].strip()
                    if query:
                        try:
                            results = search(query, return_type=str)
                            formatted_results = format_search_results(results)
                            print(formatted_results)

                            # Add to search history
                            search_entry = {
                                "turn": len(self.history) // 2,
                                "searches": [{"query": query, "results": results}],
                            }
                            self.search_history.append(search_entry)
                        except Exception as e:
                            print(f"Error searching: {e}")
                    else:
                        print("Please provide a search query.")
                    continue

                # Process as a prompt to the model
                print("\nGenerating response...")
                response = self.generate(user_input)
                print("\n----- Response -----")
                print(response)

            except KeyboardInterrupt:
                print("\nExiting...")
                break
            except Exception as e:
                print(f"Error: {e}")


def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="DeepSearch CLI")
    parser.add_argument(
        "--model_path",
        type=str,
        default="trainer_output_example/model_merged_16bit",
        help="Path to the merged 16-bit model (default: trainer_output_example/model_merged_16bit)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=None,
        help="System prompt to guide model behavior",
    )
    args = parser.parse_args()

    # Initialize and run the CLI
    cli = DeepSearchCLI(
        model_path=args.model_path,
        temperature=args.temperature,
        system_prompt=args.system_prompt,
    )
    cli.run()


if __name__ == "__main__":
    # Ensure the vectorstore is loaded
    if load_vectorstore() is None:
        print("FAISS vectorstore could not be loaded. Search functionality may not work.")

    main()
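
A similarly hedged sketch of using the CLI class above without the interactive loop. It assumes the deleted file is importable as `cli_inference` (a hypothetical module name; the diff does not show the filename) and that the FAISS vectorstore loads.

# Sketch only, not part of the deleted file. `cli_inference` is a hypothetical module name.
from cli_inference import DeepSearchCLI

agent = DeepSearchCLI(
    model_path="trainer_output_example/model_merged_16bit",  # default from main() above
    temperature=0.7,
)
answer = agent.generate("What does LoRA fine-tuning change in a transformer?")
print(answer)
agent.save_chat_history()  # writes chat_history_<timestamp>.txt to the working directory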