"""
Simple CLI inference script with search functionality.

This script allows interaction with a model (with optional LoRA adapter)
and provides search functionality for data retrieval.
"""

import argparse
import glob
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from unsloth import FastLanguageModel
from vllm import SamplingParams

from src.config import MODEL_NAME, OUTPUT_DIR, logger
from src.search_module import load_vectorstore, search

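# Example invocation (a sketch; use whatever filename this script is saved under in the
# repo, and note that "auto" resolves the newest checkpoint via find_latest_checkpoint below):
#
#   python <this_script>.py --lora_path auto --temperature 0.7
#   python <this_script>.py --lora_path none --system_prompt "You are a terse assistant."
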
def find_latest_checkpoint(search_dir=None):
    """
    Find the latest checkpoint or adapter file in the given directory.

    Args:
        search_dir: Directory to search for checkpoints. If None, a default
            trainer output directory is used.

    Returns:
        Path to the latest checkpoint, or None if no checkpoints are found.
    """
    if search_dir is None:
        search_dir = "trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236"
        logger.info(f"No search directory provided, using default: {search_dir}")
    else:
        logger.info(f"Searching for checkpoints in: {search_dir}")

    # Check that the directory exists first
    if not os.path.exists(search_dir):
        logger.warning(f"Search directory {search_dir} does not exist")
        return None

    # First, look for checkpoints in the format checkpoint-{step}
    checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))

    if checkpoints:
        # Extract checkpoint numbers and sort by number (descending)
        checkpoint_numbers = []
        for checkpoint in checkpoints:
            match = re.search(r"checkpoint-(\d+)$", checkpoint)
            if match:
                checkpoint_numbers.append((int(match.group(1)), checkpoint))

        if checkpoint_numbers:
            checkpoint_numbers.sort(reverse=True)
            latest = checkpoint_numbers[0][1]
            logger.info(f"Found latest checkpoint: {latest}")
            return latest

    # If no checkpoints were found, look for saved_adapter_{timestamp}.bin files
    adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
    if adapter_files:
        # Sort by modification time (newest first)
        adapter_files.sort(key=os.path.getmtime, reverse=True)
        latest = adapter_files[0]
        logger.info(f"Found latest adapter file: {latest}")
        return latest

    # If all else fails, look for any .bin files
    bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
    if bin_files:
        # Sort by modification time (newest first)
        bin_files.sort(key=os.path.getmtime, reverse=True)
        latest = bin_files[0]
        logger.info(f"Found latest .bin file: {latest}")
        return latest

    logger.warning(f"No checkpoints found in {search_dir}")
    return None

def setup_model_and_tokenizer():
    """Initialize the model and tokenizer with LoRA support."""
    config = {
        "max_seq_length": 4096 * 2,
        "lora_rank": 64,
        "gpu_memory_utilization": 0.6,
        "model_name": MODEL_NAME,
        "target_modules": [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    }

    logger.info(f"Setting up model {config['model_name']} with LoRA support...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=config["model_name"],
        max_seq_length=config["max_seq_length"],
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=config["lora_rank"],
        gpu_memory_utilization=config["gpu_memory_utilization"],
    )

    # Set up LoRA
    model = FastLanguageModel.get_peft_model(
        model,
        r=config["lora_rank"],
        target_modules=config["target_modules"],
        lora_alpha=config["lora_rank"],
        use_gradient_checkpointing=True,
        random_state=3407,
    )

    logger.info("Model and tokenizer setup complete.")
    return model, tokenizer

def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams:
    """Get sampling parameters for generation."""
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=max_tokens,
    )

def extract_function_calls(text: str) -> List[Dict[str, Any]]:
    """Extract function calls (JSON objects with a "function" key) from text."""
    # Pattern to match (possibly nested) JSON objects
    pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
    json_matches = re.findall(pattern, text)

    function_calls = []
    for json_str in json_matches:
        try:
            obj = json.loads(json_str)
            if "function" in obj:
                function_calls.append(obj)
        except json.JSONDecodeError:
            continue

    return function_calls

def build_user_prompt(q):
    """
    Build a user prompt with the question and the search tool definition.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
Given a question, answer it by doing searches using the search_corpus tool.
To use the search_corpus tool, respond with a JSON function call with its proper arguments.

PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.

You may also reason in any message; think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.

{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}

Question: {q}
"""
    return user_prompt

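# For reference, a tool call matching the prompt above would look like the JSON below.
# The shape is inferred from SEARCH_TOOL_DEFINITION and from how _run_tool_calls reads it
# (only function.parameters.query is actually consumed); the query text is illustrative:
#
#   {"function": {"name": "search_corpus", "parameters": {"query": "example query"}}}
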
def format_search_results(results: Union[str, List[str]]) -> str:
    """
    Format search results for display.

    Args:
        results: Search results as a string or a list of strings

    Returns:
        Formatted search results
    """
    if isinstance(results, list):
        content = "\n".join([f"Result {i + 1}:\n{r}\n------" for i, r in enumerate(results)])
    else:
        content = results

    return f"\n===== SEARCH RESULTS =====\n{content}\n===========================\n"

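# For example, format_search_results(["First passage", "Second passage"]) yields:
#
#   ===== SEARCH RESULTS =====
#   Result 1:
#   First passage
#   ------
#   Result 2:
#   Second passage
#   ------
#   ===========================
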
class DeepSearchCLI:
    """CLI for interacting with the model and search functionality."""

    def __init__(
        self,
        lora_path: Optional[str] = None,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize the CLI.

        Args:
            lora_path: Path to LoRA weights (None for the base model)
            temperature: Sampling temperature
            system_prompt: Optional system prompt to guide the model's behavior
        """
        self.model, self.tokenizer = setup_model_and_tokenizer()
        self.lora_path = lora_path
        self.temperature = temperature
        self.sampling_params = get_sampling_params(temperature)
        self.lora_request = None
        self.history = []
        self.search_history = []
        self.system_prompt = (
            system_prompt
            or f"""Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime("%d %b %Y")}

When you receive a tool call response, use the output to format an answer to the original user question.

You are a helpful assistant with tool calling capabilities."""
        )

        # Load LoRA weights if a path was given
        if self.lora_path:
            logger.info(f"Loading LoRA adapter from {self.lora_path}...")
            self.lora_request = self.model.load_lora(self.lora_path)
            if self.lora_request:
                logger.info(f"LoRA adapter loaded successfully: {self.lora_request}")
            else:
                logger.error("Failed to load LoRA adapter")

    def generate(self, prompt: str, max_generations: int = 20) -> str:
        """
        Generate a response to the prompt using an agentic loop.

        Args:
            prompt: The prompt to generate a response to
            max_generations: Maximum number of turns in the conversation

        Returns:
            The generated response after completing the conversation
        """
        # Initialize chat state
        chat_state = {
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": build_user_prompt(prompt)},
            ],
            "finished": False,
        }

        # Agent loop
        for _ in range(max_generations):
            # Generate a response
            chat_state = self._run_agent_generation(chat_state)

            # Check whether the conversation is finished
            chat_state = self._check_finished_chat(chat_state)
            if chat_state.get("finished"):
                break

            # Process tool calls, if any
            chat_state = self._run_tool_calls(chat_state)

        # Get the final response
        final_response = chat_state["messages"][-1]["content"]

        # Update history
        self.history.append({"role": "user", "content": prompt})
        self.history.append({"role": "assistant", "content": final_response})

        return final_response

    def _run_agent_generation(self, chat_state: dict) -> dict:
        """Run a single generation step for the agent."""
        formatted_prompt = self.tokenizer.apply_chat_template(
            chat_state["messages"],
            tokenize=False,
            add_generation_prompt=True,
        )

        start_time = time.time()
        if self.lora_request:
            response = self.model.fast_generate(
                [formatted_prompt],
                sampling_params=self.sampling_params,
                lora_request=self.lora_request,
            )
        else:
            response = self.model.fast_generate(
                [formatted_prompt],
                sampling_params=self.sampling_params,
            )

        gen_time = time.time() - start_time
        logger.debug(f"Generation completed in {gen_time:.2f} seconds")

        if hasattr(response[0], "outputs"):
            response_text = response[0].outputs[0].text
        else:
            response_text = response[0]

        # Extract the assistant response (text after the Llama 3 assistant header)
        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]

        chat_state["messages"].append({"role": "assistant", "content": assistant_response})

        return chat_state

    def _check_finished_chat(self, chat_state: dict) -> dict:
        """Check whether the chat is finished (no more function calls)."""
        if chat_state.get("finished"):
            return chat_state

        assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

        assistant_response = chat_state["messages"][-1]["content"]
        function_calls = extract_json_objects(assistant_response)

        if len(function_calls) == 0:
            chat_state["finished"] = True

        return chat_state

    def _run_tool_calls(self, chat_state: dict) -> dict:
        """Execute tool calls found in the chat state."""
        if chat_state.get("finished"):
            return chat_state

        try:
            assistant_response = chat_state["messages"][-1]["content"]
            function_calls = extract_json_objects(assistant_response)

            if len(function_calls) > 1:
                logger.warning("Multiple function calls found in assistant response")
                raise ValueError("Expected only one function call in assistant response")

            elif len(function_calls) == 1:
                function_call = function_calls[0]
                query = function_call["function"]["parameters"]["query"]
                logger.info(f"🔍 Search Query: {query}")

                # Results come back already formatted as "Result 1:", "Result 2:", etc.
                results = search(query, return_type=str, results=2)

                chat_state["messages"].append({"role": "ipython", "content": results})

                # Record the search in history
                search_entry = {
                    "turn": len(self.history) // 2,
                    "searches": [{"query": query, "results": results}],
                }
                self.search_history.append(search_entry)

        except Exception as e:
            logger.error(f"Error during tool call: {str(e)}")
            chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
            chat_state["finished"] = True

        return chat_state

    def clear_history(self):
        """Clear the conversation history."""
        self.history = []
        self.search_history = []
        logger.info("Conversation history cleared.")

    def set_system_prompt(self, prompt: str):
        """
        Set a new system prompt.

        Args:
            prompt: The new system prompt
        """
        if not prompt:
            logger.warning("System prompt cannot be empty. Keeping the current prompt.")
            return

        self.system_prompt = prompt
        logger.info("System prompt updated.")
        logger.info(f"New system prompt: {self.system_prompt}")

    def display_welcome(self):
        """Display the welcome message."""
        model_type = "LoRA" if self.lora_path else "Base"
        logger.info(f"\n{'=' * 50}")
        logger.info(f"DeepSearch CLI - {model_type} Model")
        logger.info(f"Model: {MODEL_NAME}")
        logger.info(f"Temperature: {self.temperature}")
        if self.lora_path:
            logger.info(f"LoRA Path: {self.lora_path}")
        logger.info(f"System Prompt: {self.system_prompt}")
        logger.info(f"{'=' * 50}")
        logger.info("Type 'help' to see available commands.")

    def print_pretty_chat_history(self):
        """Print the full chat history in a readable format, including searches."""
        if not self.history:
            logger.info("No chat history available.")
            return

        logger.info("\n" + "=" * 80)
        logger.info("CHAT HISTORY WITH SEARCH DETAILS")
        logger.info("=" * 80)

        # Group the history into conversation turns
        for i in range(0, len(self.history), 2):
            turn_number = i // 2

            # Print the user message
            if i < len(self.history):
                user_msg = self.history[i]["content"]
                logger.info(f"\n[Turn {turn_number + 1}] USER: ")
                logger.info("-" * 40)
                logger.info(user_msg)

            # Print any searches associated with this turn
            for search_entry in self.search_history:
                if search_entry["turn"] == turn_number:
                    for idx, search_item in enumerate(search_entry["searches"]):
                        logger.info(f'\n🔍 SEARCH {idx + 1}: "{search_item["query"]}"')
                        logger.info("-" * 40)
                        logger.info(search_item["results"])

            # Print the assistant response
            if i + 1 < len(self.history):
                assistant_msg = self.history[i + 1]["content"]
                logger.info(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
                logger.info("-" * 40)
                logger.info(assistant_msg)

        logger.info("\n" + "=" * 80 + "\n")

    def save_chat_history(self, filepath=None):
        """
        Save the chat history to a text file.

        Args:
            filepath: Path to the save file (if None, auto-generated from the timestamp)

        Returns:
            Path to the saved file, or None if nothing was saved
        """
        if not self.history:
            logger.info("No chat history to save.")
            return None

        # Generate a default filepath if none was provided
        if filepath is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_type = "lora" if self.lora_path else "base"
            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.txt")

        # Ensure the directory exists
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # Prepare the chat history data
        pretty_history = []

        # Group the history into conversation turns
        for i in range(0, len(self.history), 2):
            turn_number = i // 2
            turn_data = {
                "turn": turn_number + 1,
                "user": self.history[i]["content"] if i < len(self.history) else "",
                "searches": [],
                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
            }

            # Add the searches for this turn
            for search_entry in self.search_history:
                if search_entry["turn"] == turn_number:
                    turn_data["searches"].extend(search_entry["searches"])

            pretty_history.append(turn_data)

        # Write to the file
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(f"{'=' * 80}\n")
                f.write("DEEPSEARCH CHAT HISTORY\n")
                f.write(f"Model: {MODEL_NAME}\n")
                f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
                f.write(f"Temperature: {self.temperature}\n")
                f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"{'=' * 80}\n\n")

                for turn in pretty_history:
                    f.write(f"[Turn {turn['turn']}] USER:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['user']}\n\n")

                    # Write the searches
                    for i, search_item in enumerate(turn["searches"]):
                        f.write(f'🔍 SEARCH {i + 1}: "{search_item["query"]}"\n')
                        f.write(f"{'-' * 40}\n")
                        f.write(f"{search_item['results']}\n\n")

                    f.write(f"[Turn {turn['turn']}] ASSISTANT:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['assistant']}\n\n")
                    f.write(f"{'=' * 40}\n\n")

            logger.info(f"Chat history saved to: {filepath}")
            return filepath

        except Exception as e:
            logger.error(f"Error saving chat history: {e}")
            return None

    def save_chat_history_json(self, filepath=None):
        """
        Save the chat history to a JSON file.

        Args:
            filepath: Path to the save file (if None, auto-generated from the timestamp)

        Returns:
            Path to the saved file, or None if nothing was saved
        """
        if not self.history:
            logger.info("No chat history to save.")
            return None

        # Generate a default filepath if none was provided
        if filepath is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_type = "lora" if self.lora_path else "base"
            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.json")

        # Ensure the directory exists
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # Prepare the chat history data
        history_data = {
            "model": MODEL_NAME,
            "lora_path": self.lora_path if self.lora_path else None,
            "temperature": self.temperature,
            "timestamp": datetime.now().isoformat(),
            "turns": [],
        }

        # Group the history into conversation turns
        for i in range(0, len(self.history), 2):
            turn_number = i // 2
            turn_data = {
                "turn": turn_number + 1,
                "user": self.history[i]["content"] if i < len(self.history) else "",
                "searches": [],
                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
            }

            # Add the searches for this turn
            for search_entry in self.search_history:
                if search_entry["turn"] == turn_number:
                    turn_data["searches"].extend(search_entry["searches"])

            history_data["turns"].append(turn_data)

        # Write to the file
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(history_data, f, indent=2, ensure_ascii=False)

            logger.info(f"Chat history saved to JSON: {filepath}")
            return filepath

        except Exception as e:
            logger.error(f"Error saving chat history to JSON: {e}")
            return None

    def display_help(self):
        """Display help information."""
        logger.info("\n===== Commands =====")
        logger.info("search <query> - Search for information")
        logger.info("system <prompt> - Set a new system prompt")
        logger.info("clear - Clear conversation history")
        logger.info("history - Display full chat history with searches")
        logger.info("save - Save chat history to a text file")
        logger.info("savejson - Save chat history to a JSON file")
        logger.info("help - Display this help message")
        logger.info("exit/quit - Exit the program")
        logger.info("Any other input will be treated as a prompt to the model.")
        logger.info("===================\n")

    def run(self):
        """Run the CLI."""
        self.display_welcome()

        while True:
            try:
                user_input = input("\n> ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ["exit", "quit"]:
                    logger.info("Exiting...")
                    break

                if user_input.lower() == "help":
                    self.display_help()
                    continue

                if user_input.lower() == "clear":
                    self.clear_history()
                    continue

                if user_input.lower() == "history":
                    self.print_pretty_chat_history()
                    continue

                if user_input.lower() == "save":
                    self.save_chat_history()
                    continue

                if user_input.lower() == "savejson":
                    self.save_chat_history_json()
                    continue

                if user_input.lower().startswith("system "):
                    new_prompt = user_input[7:].strip()
                    self.set_system_prompt(new_prompt)
                    continue

                if user_input.lower().startswith("search "):
                    query = user_input[7:].strip()
                    if query:
                        try:
                            results = search(query, return_type=str)
                            formatted_results = format_search_results(results)
                            logger.info(formatted_results)

                            # Add to the search history
                            search_entry = {
                                "turn": len(self.history) // 2,
                                "searches": [{"query": query, "results": results}],
                            }
                            self.search_history.append(search_entry)
                        except Exception as e:
                            logger.error(f"Error searching: {e}")
                    else:
                        logger.warning("Please provide a search query.")
                    continue

                # Treat anything else as a prompt to the model
                logger.info("\nGenerating response...")
                response = self.generate(user_input)
                logger.info("\n----- Response -----")
                logger.info(response)

            except KeyboardInterrupt:
                logger.info("\nExiting...")
                break
            except Exception as e:
                logger.error(f"Error: {e}")

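# DeepSearchCLI can also be driven programmatically rather than through run(); a minimal
# sketch (the question is illustrative, and this loads the full model, so it needs the
# same GPU resources as interactive use):
#
#   cli = DeepSearchCLI(lora_path=None, temperature=0.7)
#   answer = cli.generate("Your question here")
#   cli.save_chat_history_json()
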
def extract_json_objects(text):
    """
    Extract JSON objects (dictionaries) from text that may contain several of them.

    Args:
        text (str): The input text possibly containing JSON objects.

    Returns:
        list: A list of parsed JSON objects (dictionaries) extracted from the text.
    """
    results = []
    length = len(text)
    i = 0

    while i < length:
        # Look for the start of a JSON object
        if text[i] == "{":
            start = i
            stack = 1
            i += 1
            # Continue until we find the matching closing brace
            while i < length and stack > 0:
                if text[i] == "{":
                    stack += 1
                elif text[i] == "}":
                    stack -= 1
                i += 1
            # Only attempt to decode if the braces are balanced
            if stack == 0:
                candidate = text[start:i]
                try:
                    obj = json.loads(candidate)
                    # Only keep dictionaries; other JSON values are ignored
                    if isinstance(obj, dict):
                        results.append(obj)
                except json.JSONDecodeError:
                    # Not valid JSON; skip it
                    pass
        else:
            i += 1
    return results

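# For example:
#
#   extract_json_objects('prefix {"function": {"name": "search_corpus"}} suffix')
#   -> [{'function': {'name': 'search_corpus'}}]
#
# Non-dictionary JSON values and unbalanced or invalid candidates are ignored.
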
# Tool definition for corpus search
SEARCH_TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "search_corpus",
        "description": "Search over the knowledge corpus with a given query",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The query to search the knowledge corpus with",
                },
            },
            "required": ["query"],
        },
    },
}

def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="DeepSearch CLI")
    parser.add_argument(
        "--lora_path",
        type=str,
        default="auto",
        help="Path to LoRA weights ('none' for the base model, 'auto' for auto-detection)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=None,
        help="System prompt to guide model behavior",
    )
    args = parser.parse_args()

    # Resolve the LoRA path ('none' disables LoRA, 'auto' picks the latest checkpoint)
    lora_path = None
    if args.lora_path and args.lora_path.lower() != "none":
        if args.lora_path == "auto":
            detected_path = find_latest_checkpoint()
            if detected_path:
                lora_path = detected_path
                logger.info(f"Auto-detected LoRA path: {lora_path}")
            else:
                logger.warning("No LoRA checkpoint found. Using base model.")
        else:
            lora_path = args.lora_path

    # Initialize and run the CLI
    cli = DeepSearchCLI(
        lora_path=lora_path,
        temperature=args.temperature,
        system_prompt=args.system_prompt,
    )
    cli.run()

if __name__ == "__main__":
    # Ensure the vectorstore is loaded
    if load_vectorstore() is None:
        logger.warning("FAISS vectorstore could not be loaded. Search functionality may not work.")

    main()