You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
2.2 KiB

"""
Simple script to load unsloth checkpoint and save to FP16 format.
"""
import os
from unsloth import FastLanguageModel
def load_model(
model_name: str,
max_seq_length: int = 8192,
load_in_4bit: bool = True,
fast_inference: bool = True,
max_lora_rank: int = 64,
gpu_memory_utilization: float = 0.6,
):
"""Load model and tokenizer with unsloth."""
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
fast_inference=fast_inference,
max_lora_rank=max_lora_rank,
gpu_memory_utilization=gpu_memory_utilization,
)
return model, tokenizer
def save_to_fp16(checkpoint_dir: str, output_dir: str | None = None):
"""
Load unsloth checkpoint and save to FP16 format.
Args:
checkpoint_dir: Directory containing the checkpoint
output_dir: Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)
"""
if output_dir is None:
# Get parent directory of checkpoint and create model_merged_16bit there
parent_dir = os.path.dirname(checkpoint_dir)
output_dir = os.path.join(parent_dir, "model_merged_16bit")
# Load model and tokenizer
print(f"Loading model from {checkpoint_dir}")
model, tokenizer = load_model(checkpoint_dir)
# Save to FP16
print(f"Saving model to FP16 in {output_dir}")
model.save_pretrained_merged(
output_dir,
tokenizer,
save_method="merged_16bit",
)
print("Done!")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Save unsloth checkpoint to FP16")
parser.add_argument(
"checkpoint_dir",
nargs="?",
default="trainer_output_example/checkpoint-101",
help="Directory containing the checkpoint (default: trainer_output_example/checkpoint-101)",
)
parser.add_argument(
"--output_dir",
help="Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)",
)
args = parser.parse_args()
save_to_fp16(args.checkpoint_dir, args.output_dir)