diff --git a/scripts/save_merged_16bit.py b/scripts/save_merged_16bit.py
new file mode 100644
index 0000000..0449809
--- /dev/null
+++ b/scripts/save_merged_16bit.py
@@ -0,0 +1,86 @@
+"""
+Simple script to load unsloth checkpoint and save to FP16 format.
+"""
+
+import argparse
+import os
+
+from unsloth import FastLanguageModel
+
+
+def load_model(
+    model_name: str,
+    max_seq_length: int = 8192,
+    load_in_4bit: bool = True,
+    fast_inference: bool = True,
+    max_lora_rank: int = 64,
+    gpu_memory_utilization: float = 0.6,
+):
+    """Load model and tokenizer with unsloth.
+
+    Every argument is forwarded unchanged, by keyword, to
+    ``FastLanguageModel.from_pretrained``.
+
+    Returns:
+        The ``(model, tokenizer)`` pair returned by unsloth.
+    """
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=model_name,
+        max_seq_length=max_seq_length,
+        load_in_4bit=load_in_4bit,
+        fast_inference=fast_inference,
+        max_lora_rank=max_lora_rank,
+        gpu_memory_utilization=gpu_memory_utilization,
+    )
+    return model, tokenizer
+
+
+def save_to_fp16(checkpoint_dir: str, output_dir: str | None = None):
+    """
+    Load unsloth checkpoint and save to FP16 format.
+
+    Args:
+        checkpoint_dir: Directory containing the checkpoint
+        output_dir: Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)
+    """
+    if output_dir is None:
+        # Resolve the path first: os.path.dirname() of "run/checkpoint-101/"
+        # (trailing separator) is the checkpoint itself, not its parent, so the
+        # merged model would otherwise be written inside the checkpoint.
+        parent_dir = os.path.dirname(os.path.abspath(checkpoint_dir))
+        output_dir = os.path.join(parent_dir, "model_merged_16bit")
+
+    # Load model and tokenizer
+    print(f"Loading model from {checkpoint_dir}")
+    model, tokenizer = load_model(checkpoint_dir)
+
+    # Save to FP16
+    print(f"Saving model to FP16 in {output_dir}")
+    model.save_pretrained_merged(
+        output_dir,
+        tokenizer,
+        save_method="merged_16bit",
+    )
+    print("Done!")
+
+
+def main():
+    """Parse command-line arguments and run the FP16 export."""
+    parser = argparse.ArgumentParser(description="Save unsloth checkpoint to FP16")
+    parser.add_argument(
+        "checkpoint_dir",
+        nargs="?",
+        default="trainer_output_example/checkpoint-101",
+        help="Directory containing the checkpoint (default: trainer_output_example/checkpoint-101)",
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Directory to save the FP16 model (default: model_merged_16bit in parent of checkpoint_dir)",
+    )
+    args = parser.parse_args()
+
+    save_to_fp16(args.checkpoint_dir, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_grpo.py b/train_grpo.py
index bd9ee63..e29fdae 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -120,3 +120,13 @@ if __name__ == "__main__":
     trainer.train()
     logger.info("Training completed")
     logger.info(f"Model saved to {OUTPUT_DIR}")
+
+    # Save model to FP16 format
+    logger.info("Saving model to FP16 format")
+    model_merged_dir = os.path.join(OUTPUT_DIR, "model_merged_16bit")
+    model.save_pretrained_merged(
+        model_merged_dir,
+        tokenizer,
+        save_method="merged_16bit",
+    )
+    logger.info(f"FP16 model saved to {model_merged_dir}")