import%20marimo%0A%0A__generated_with%20%3D%20%220.10.12%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20%23%20Launch%20a%20Llama%203%20fine-tuning%20task%20with%20marimo%20and%20Runhouse%0A%20%20%20%20%23%20In%20this%20example%2C%20you%20will%20launch%20a%20remote%20cluster%2C%20write%20your%20training%20in%20a%20marimo%20cell%2C%20and%20then%20send%20your%20training%20code%20to%20that%20remote%20compute%20to%20run.%20This%20way%2C%20you%20can%20access%20unlimited%20power%20and%20scaling%20potential%20while%20still%20having%20a%20great%20dev%20experience.%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20%23%20Setup%20steps%2C%20simply%20importing%20marimo%20and%20Runhouse%20libraries%20and%20setting%20your%20HuggingFace%20Token%20which%20is%20needed%20to%20download%20the%20libraries.%20We%20assume%20by%20opening%20this%2C%20you%20already%20have%20marimo%20installed%3B%20Runhouse%20is%20installed%20with%20the%20additional%20%60aws%60%20though%20you%20could%20do%20%60azure%60%20or%20%60gcp%60%20instead%20%0A%20%20%20%20%23%20First%2C%20pip%20install%20%22runhouse%5Baws%5D%22%20%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20runhouse%20as%20rh%0A%0A%20%20%20%20import%20os%0A%20%20%20%20os.environ%5B%22HF_TOKEN%22%5D%20%3D%20%22My%20Hugging%20Face%20Token%22%20%23%20Used%20to%20download%20Llama3%20weights%20from%20HF%20-%20make%20sure%20you%20sign%20the%20consent%20form%0A%20%20%20%20return%20mo%2C%20os%2C%20rh%0A%0A%0A%40app.cell%0Adef%20_(rh)%3A%0A%20%20%20%20%23%20Launch%20a%20remote%20cluster%20with%20a%20GPU%20-%20here%20we%20use%20an%20L4%20from%20AWS.%20We%20assume%20you%20have%20already%20set%20up%20with%20your%20cloud%20provider%20CLI%20(e.g.%20%60aws%20configure%60%20or%20%60gcloud%20init%60)%20and%20your%20account%20is%20authorized%20to%20launch%20compute.%20%0A%20%20%20%20img%20%3D%20(%0A%20%20%20%20%20%20%20%20rh.Image(name%3D%22llama3finetuning%22)%0A%20%20%20%20%20%20%20%20.install_packages(%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22torch%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22tensorboard%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22transformers%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22bitsandbytes%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22peft%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22trl%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22accelerate%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22scipy%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22marimo%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20.sync_secrets(%5B%22huggingface%22%5D)%0A%20%20%20%20)%0A%0A%20%20%20%20cluster%20%3D%20rh.cluster(%0A%20%20%20%20%20%20%20%20name%3D%22rh-L4%22%2C%20instance_type%3D%22L4%3A1%22%2C%20provider%3D%22aws%22%2C%20image%3Dimg%2C%20autostop_mins%20%3D%20120%0A%20%20%20%20).up_if_not()%0A%20%20%20%20return%20cluster%2C%20img%0A%0A%0A%40app.cell(disabled%3DTrue)%0Adef%20finetune()%3A%0A%20%20%20%20%23%20Define%20the%20fine%20tuning%20in%20this%20cell%20that%20we%20will%20dispatch%20and%20run%20in%20the%20remote%20cluster%20we%20launched.%20This%20cell%20is%20named%20%22finetuner%22%20so%20we%20can%20dispatch%20it%20by%20function%20name%20in%20the%20next%20step%0A%0A%20%20%20%20%23%20Dispatching%20and%20executing%20will%20happen%20in%20the%20next%20cell%20-%20but%20by%20making%20edits%20here%20and%20rerunning%20the%20next%20cell%2C%20your%20iteration%20loops%20should%20feel%20local-like%20(dispatch%20takes%20%3C2s%20to%20remote)%0A%20%20%20%20from%20transformers%20import%20AutoModelForCausalLM%2C%20BitsAndBytesConfig%2C%20AutoTokenizer%2C%20pipeline%0A%20%20%20%20import%20gc%20%0A%20%20%20%20from%20trl%20import%20SFTTrainer%2C%20SFTConfig%0A%20%20%20%20from%20peft%20import%20AutoPeftModelForCausalLM%2C%20LoraConfig%0A%20%20%20%20from%20pathlib%20import%20Path%0A%20%20%20%20from%20accelerate%20import%20PartialState%2C%20Accelerator%0A%20%20%20%20import%20torch%20%0A%20%20%20%20from%20datasets%20import%20load_dataset%0A%0A%20%20%20%20DEFAULT_MAX_LENGTH%20%3D%20200%0A%0A%0A%20%20%20%20DATASET_NAME%20%3D%20%22mlabonne%2Fguanaco-llama2-1k%22%0A%20%20%20%20BASE_MODEL_NAME%20%3D%20%22meta-llama%2FLlama-3.2-3B-Instruct%22%0A%20%20%20%20FINE_TUNED_MODEL_NAME%20%3D%20%22llama-3-3b-enhanced%22%0A%0A%20%20%20%20def%20load_base_model(base_model_name)%3A%0A%20%20%20%20%20%20%20%20quant_config%20%3D%20BitsAndBytesConfig(%0A%20%20%20%20%20%20%20%20%20%20%20%20load_in_4bit%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20bnb_4bit_quant_type%3D%22nf4%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20bnb_4bit_compute_dtype%3Dtorch.bfloat16%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20bnb_4bit_use_double_quant%3DFalse%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20base_model%20%3D%20AutoModelForCausalLM.from_pretrained(%0A%20%20%20%20%20%20%20%20%20%20%20%20base_model_name%2C%20quantization_config%3Dquant_config%2C%20device_map%3D%7B%22%22%3A%200%7D%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20base_model.config.use_cache%20%3D%20False%0A%20%20%20%20%20%20%20%20base_model.config.pretraining_tp%20%3D%201%0A%20%20%20%20%20%20%20%20return%20base_model%0A%0A%20%20%20%20def%20load_tokenizer(base_model_name)%3A%0A%20%20%20%20%20%20%20%20tokenizer%20%3D%20AutoTokenizer.from_pretrained(base_model_name%2C%20trust_remote_code%3DTrue)%0A%20%20%20%20%20%20%20%20tokenizer.pad_token%20%3D%20tokenizer.eos_token%0A%20%20%20%20%20%20%20%20tokenizer.padding_side%20%3D%20%22right%22%0A%20%20%20%20%20%20%20%20return%20tokenizer%0A%0A%20%20%20%20def%20load_pipeline(fine_tuned_model%2C%20tokenizer%2C%20max_length)%3A%0A%20%20%20%20%20%20%20%20return%20pipeline(%0A%20%20%20%20%20%20%20%20%20%20%20%20task%3D%22text-generation%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20model%3Dfine_tuned_model%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20tokenizer%3Dtokenizer%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_length%3Dmax_length%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20load_dataset_data(dataset_name)%3A%0A%20%20%20%20%20%20%20%20return%20load_dataset(dataset_name%2C%20split%3D%22train%22)%0A%0A%20%20%20%20def%20training_params()%3A%0A%20%20%20%20%20%20%20%20return%20SFTConfig(%0A%20%20%20%20%20%20%20%20%20%20%20%20output_dir%3D%22.%2Fresults_modified%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20num_train_epochs%3D1%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20per_device_train_batch_size%3D2%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20gradient_accumulation_steps%3D1%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20optim%3D%22paged_adamw_32bit%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20save_steps%3D25%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20logging_steps%3D25%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20learning_rate%3D2e-4%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20weight_decay%3D0.001%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20fp16%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20bf16%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_grad_norm%3D0.3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_steps%3D-1%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20warmup_ratio%3D0.03%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset_text_field%3D%22text%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20group_by_length%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20lr_scheduler_type%3D%22constant%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20report_to%3D%22tensorboard%22%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20sft_trainer(base_model%2C%20training_data%2C%20peft_parameters%2C%20tokenizer%2C%20train_params)%3A%0A%20%20%20%20%20%20%20%20return%20SFTTrainer(%0A%20%20%20%20%20%20%20%20%20%20%20%20model%3Dbase_model%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20train_dataset%3Dtraining_data%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20peft_config%3Dpeft_parameters%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20tokenizer%3Dtokenizer%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20args%3Dtrain_params%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20tune(base_model_name%2C%20dataset_name%2C%20fine_tuned_model_name)%3A%0A%20%20%20%20%20%20%20%20gc.collect()%0A%20%20%20%20%20%20%20%20torch.cuda.empty_cache()%0A%20%20%20%20%20%20%20%20gc.collect()%0A%0A%20%20%20%20%20%20%20%20training_data%20%3D%20load_dataset_data(dataset_name)%0A%20%20%20%20%20%20%20%20tokenizer%20%3D%20load_tokenizer(base_model_name)%0A%20%20%20%20%20%20%20%20base_model%20%3D%20load_base_model(base_model_name)%0A%0A%20%20%20%20%20%20%20%20peft_parameters%20%3D%20LoraConfig(%0A%20%20%20%20%20%20%20%20%20%20%20%20lora_alpha%3D16%2C%20lora_dropout%3D0.1%2C%20r%3D8%2C%20bias%3D%22none%22%2C%20task_type%3D%22CAUSAL_LM%22%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20train_params%20%3D%20training_params()%0A%20%20%20%20%20%20%20%20trainer%20%3D%20sft_trainer(base_model%2C%20training_data%2C%20peft_parameters%2C%20tokenizer%2C%20train_params)%0A%0A%20%20%20%20%20%20%20%20trainer.train()%0A%0A%20%20%20%20%20%20%20%20trainer.model.save_pretrained(fine_tuned_model_name)%0A%20%20%20%20%20%20%20%20trainer.tokenizer.save_pretrained(fine_tuned_model_name)%0A%20%20%20%20%20%20%20%20print(%22Saved%20model%20weights%20and%20tokenizer%20on%20the%20cluster.%22)%0A%0A%20%20%20%20def%20load_fine_tuned_model(fine_tuned_model_name)%3A%0A%20%20%20%20%20%20%20%20if%20not%20Path(f%22~%2F%7Bfine_tuned_model_name%7D%22).expanduser().exists()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20raise%20FileNotFoundError(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22No%20fine-tuned%20model%20found%20on%20the%20cluster.%20Call%20the%20%60tune%60%20method%20to%20run%20the%20fine-tuning.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20fine_tuned_model%20%3D%20AutoPeftModelForCausalLM.from_pretrained(%0A%20%20%20%20%20%20%20%20%20%20%20%20fine_tuned_model_name%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20device_map%3D%7B%22%22%3A%20%22cuda%3A0%22%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20torch_dtype%3Dtorch.bfloat16%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20return%20fine_tuned_model.merge_and_unload()%0A%0A%20%20%20%20def%20generate(query%2C%20fine_tuned_model_name%2C%20max_length%3DDEFAULT_MAX_LENGTH)%3A%0A%20%20%20%20%20%20%20%20fine_tuned_model%20%3D%20load_fine_tuned_model(fine_tuned_model_name)%0A%20%20%20%20%20%20%20%20tokenizer%20%3D%20load_tokenizer(BASE_MODEL_NAME)%0A%20%20%20%20%20%20%20%20gen_pipeline%20%3D%20load_pipeline(fine_tuned_model%2C%20tokenizer%2C%20max_length)%0A%0A%20%20%20%20%20%20%20%20output%20%3D%20gen_pipeline(f%22%3Cs%3E%5BINST%5D%20%7Bquery%7D%20%5B%2FINST%5D%22)%0A%20%20%20%20%20%20%20%20return%20output%5B0%5D%5B%22generated_text%22%5D%0A%0A%20%20%20%20tune(BASE_MODEL_NAME%2C%20DATASET_NAME%2C%20FINE_TUNED_MODEL_NAME)%0A%20%20%20%20print(generate(%22What%20is%20the%20capital%20of%20France%3F%22%2C%20FINE_TUNED_MODEL_NAME))%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20Accelerator%2C%0A%20%20%20%20%20%20%20%20AutoModelForCausalLM%2C%0A%20%20%20%20%20%20%20%20AutoPeftModelForCausalLM%2C%0A%20%20%20%20%20%20%20%20AutoTokenizer%2C%0A%20%20%20%20%20%20%20%20BASE_MODEL_NAME%2C%0A%20%20%20%20%20%20%20%20BitsAndBytesConfig%2C%0A%20%20%20%20%20%20%20%20DATASET_NAME%2C%0A%20%20%20%20%20%20%20%20DEFAULT_MAX_LENGTH%2C%0A%20%20%20%20%20%20%20%20FINE_TUNED_MODEL_NAME%2C%0A%20%20%20%20%20%20%20%20LoraConfig%2C%0A%20%20%20%20%20%20%20%20PartialState%2C%0A%20%20%20%20%20%20%20%20Path%2C%0A%20%20%20%20%20%20%20%20SFTConfig%2C%0A%20%20%20%20%20%20%20%20SFTTrainer%2C%0A%20%20%20%20%20%20%20%20gc%2C%0A%20%20%20%20%20%20%20%20generate%2C%0A%20%20%20%20%20%20%20%20load_base_model%2C%0A%20%20%20%20%20%20%20%20load_dataset%2C%0A%20%20%20%20%20%20%20%20load_dataset_data%2C%0A%20%20%20%20%20%20%20%20load_fine_tuned_model%2C%0A%20%20%20%20%20%20%20%20load_pipeline%2C%0A%20%20%20%20%20%20%20%20load_tokenizer%2C%0A%20%20%20%20%20%20%20%20pipeline%2C%0A%20%20%20%20%20%20%20%20sft_trainer%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20%20%20%20%20training_params%2C%0A%20%20%20%20%20%20%20%20tune%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20%23%20Some%20parameters%20that%20are%20available%20for%20use%20in%20training%0A%20%20%20%20train_params%20%3D%20mo.md(%22%22%22%7Bepochs%7D%20%5Cn%0A%20%20%20%20%7Bper_device_train_batch_size%7D%20%5Cn%0A%20%20%20%20%7Bgradient_accumulation_steps%7D%20%5Cn%0A%20%20%20%20%7Boptimizer%7D%20%5Cn%0A%20%20%20%20%7Blearning_rate%7D%20%5Cn%0A%20%20%20%20%7Boutput_dir%7D%20%5Cn%20%0A%20%20%20%20%7Bdataset_text_field%7D%0A%20%20%20%20%22%22%22%0A%20%20%20%20).batch(%0A%20%20%20%20%20%20%20%20epochs%20%3D%20mo.ui.number(start%3D1%2C%20stop%3D100%2C%20step%3D1%2C%20label%3D'Epochs')%2C%0A%20%20%20%20%20%20%20%20per_device_train_batch_size%20%3D%20mo.ui.number(start%3D1%2C%20stop%3D100%2C%20step%3D1%2C%20label%3D'Batch%20Size')%2C%0A%20%20%20%20%20%20%20%20gradient_accumulation_steps%20%3D%20mo.ui.slider(start%3D1%2C%20stop%3D10%2C%20step%3D1%2C%20label%20%3D%20'Gradient%20Accumulation%20Steps')%2C%0A%20%20%20%20%20%20%20%20optimizer%20%3D%20mo.ui.dropdown(options%3D%5B%22paged_adamw_32bit%22%2C%20%22paged_adamw_8bit%22%5D%2C%20value%3D%22paged_adamw_32bit%22%2C%20label%3D%22Optimizer%22)%2C%0A%20%20%20%20%20%20%20%20learning_rate%20%3D%20mo.ui.number(start%3D0%2C%20stop%3D1%2C%20value%20%3D%202e-4%2C%20label%20%3D%20'Learn%20Rate'%2C%20step%3D0.0001)%2C%0A%20%20%20%20%20%20%20%20output_dir%20%3D%20mo.ui.text(value%3D%22.%2Fresults_modified%22%2C%20label%20%3D%20'Output%20Directory')%2C%0A%20%20%20%20%20%20%20%20dataset_text_field%20%3D%20mo.ui.text(value%3D%22text%22%2C%20label%20%3D%20'Dataset%20Text%20Field')%0A%0A%20%20%20%20)%0A%0A%20%20%20%20train_params%0A%20%20%20%20return%20(train_params%2C)%0A%0A%0A%40app.cell%0Adef%20_(cluster%2C%20mo%2C%20rh)%3A%0A%20%20%20%20%23%20Will%20run%20for%20about%20~25%20min%20to%20download%20and%20tune%20the%20model.%20%0A%20%20%20%20from%20runhouse_marimo%20import%20finetune%20%0A%20%20%20%20fine_tuner_remote%20%3D%20rh.function(finetune).to(cluster%2C%20name%3D%22ft_model%22)%0A%0A%20%20%20%20%23%20For%20cleanliness%2C%20we%20capture%20the%20outputs.%20You%20can%20watch%20the%20training%20happen%20verbosely%20by%20removing%20the%20capture%20of%20the%20standard%20out.%20%0A%20%20%20%20with%20mo.capture_stdout()%20as%20buffer%3A%20%0A%20%20%20%20%20%20%20%20fine_tuner_remote()%0A%20%20%20%20return%20buffer%2C%20fine_tuner_remote%2C%20finetune%0A%0A%0A%40app.cell%0Adef%20_(buffer)%3A%0A%20%20%20%20output%20%3D%20buffer.getvalue()%0A%20%20%20%20output%5B%3A500%5D%20%2B%20%22%5Cn...........%5Cn%5Cn...........%22%20%2B%20output%5B-1200%3A%5D%0A%0A%0A%20%20%20%20return%20(output%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
da4705f900d442e8ab7932bdef235bada3c88130e0ab68520bcbe4beecd01141