import%20marimo%0A%0A__generated_with%20%3D%20%220.8.0%22%0Aapp%20%3D%20marimo.App(app_title%3D%22IOAI%202024%20CV%20On-Site%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20return%20mo%2C%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%20IOAI%202024%3A%20CV%20On-Site%20Task%20Solution%0A%0A%20%20%20%20%20%20%20%20This%20is%20the%20solution%20my%20team%20submitted%20for%20the%20computer%20vision%20task%20of%20the%202024%20International%20Olympiad%20in%20AI.%20It%20achieved%20a%20min-max%20normalised%20score%20of%2099%25%20ranking%203rd%2F43%20teams.%20Only%2011%20teams%20received%20a%20score%20higher%20than%20baseline%20and%20we%20were%201%20of%204%20teams%20that%20produced%20a%20near-optimal%20solution.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.image(%0A%20%20%20%20%20%20%20%20src%3D%22https%3A%2F%2Fwww.glennwu.com%2Fimages%2Fmarimo%2Fcv_banner.png%22%2C%0A%20%20%20%20%20%20%20%20alt%3D%22Cow%20and%20fire%20hydrant%20at%20sunset%22%2C%0A%20%20%20%20%20%20%20%20rounded%3DTrue%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%20The%20Problem%3A%20%22The%20Madarian%20Cow%20Mystery%22%0A%0A%20%20%20%20%20%20%20%20%23%23%23%20Story%0A%20%20%20%20%20%20%20%20Following%20your%20successful%20adaptation%20of%20the%20image%20generation%20AI%20to%20accommodate%20the%20Madarian%20language%20quirk%20regarding%20zebras%20and%20giraffes%2C%20your%20team%20has%20made%20significant%20progress%20in%20fostering%20communication%20and%20cultural%20exchange%20with%20the%20inhabitants%20of%20Madaria.%20Your%20efforts%20have%20not%20gone%20unnoticed%2C%20and%20you've%20been%20entrusted%20with%20a%20new%20challenge.%0A%0A%20%20%20%20%20%20%20%20During%20a%20routine%20survey%20of%20Madarian%20farmlands%2C%20your%20team%20stumbles%20upon%20a%20peculiar%20sight.%20What%20appears%20to%20be%20a%20standard%20Earth%20fire%20hydrant%20stands%20proudly%20in%20the%20middle%20of%20a%20field%2C%20surrounded%20by%20cows.%20Upon%20closer%20inspection%2C%20you%20realize%20that%20these%20fire%20hydrants%20are%20indeed%20identical%20to%20those%20on%20Earth%2C%20but%20their%20purpose%20and%20significance%20on%20Madaria%20are%20entirely%20different.%0A%0A%20%20%20%20%20%20%20%20The%20Madarians%20have%20developed%20a%20deep%20cultural%20and%20spiritual%20connection%20to%20these%20fire%20hydrants%2C%20considering%20them%20sacred%20guardians%20of%20their%20livestock.%20They%20believe%20that%20the%20presence%20of%20these%20hydrants%20ensures%20the%20health%20and%20prosperity%20of%20their%20cow%20herds.%20As%20a%20result%2C%20Madarian%20farmers%20always%20expect%20to%20see%20a%20fire%20hydrant%20in%20any%20depiction%20or%20image%20of%20their%20cattle.%0A%0A%0A%20%20%20%20%20%20%20%20The%20sensitivity%20of%20the%20situation%20pushes%20you%20to%20make%20changes%20fast%2C%20so%20you%20won't%20be%20retraining%20the%20full%20model%2C%20just%20a%20modifier%20for%20the%20initial%20embeddings%20and%20latent%20representations.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%20Your%20Mission%0A%0A%20%20%20%20%20%20%20%20Modify%20your%20image%20generation%20AI%20to%20automatically%20include%20a%20fire%20hydrant%20in%20any%20image%20where%20a%20cow%20is%20expected.%20This%20will%20align%20with%20Madarian%20expectations%20and%20cultural%20norms.%0A%20%20%20%20%20%20%20%20Ensure%20that%20the%20AI%20does%20not%20include%20fire%20hydrants%20when%20generating%20images%20of%20other%20animals%2C%20maintaining%20accuracy%20for%20all%20other%20fauna.%20No%20need%20to%20switch%20zebra%2Fgiraffe.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%20Formal%20Task%0A%0A%20%20%20%20%20%20%20%20-%20Draw%20a%20fire%20hydrant%20in%20the%20image%20when%20the%20prompt%20requires%20drawing%20a%20cow.%0A%20%20%20%20%20%20%20%20-%20Don't%20draw%20a%20fire%20hydrant%20in%20other%20images.%20There%20will%20be%20no%20direct%20'fire%20hydrant'%20prompts%20in%20the%20test.%0A%20%20%20%20%20%20%20%20-%20You%20will%20use%20the%20familiar%20to%20you%20%60miniSD-diffusers%60%20model%20for%20inference%2C%20but%20you%20will%20only%20be%20able%20to%20modify%20text%20embeddings%20and%20initial%20latent%20representations.%0A%20%20%20%20%20%20%20%20-%20Please%20make%20sure%20you%20don't%20use%20any%20external%20data%20except%20the%20provided%20dataset%20and%20don't%20add%20more%20arguments%20to%20magic%20modifier%20function.%20The%20solution%20will%20**not**%20be%20scored%20otherwise.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20).callout()%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%20Deliverables%0A%20%20%20%20%20%20%20%20-%20This%20notebook%20with%20code%20that%20reproduces%20your%20solution%0A%20%20%20%20%20%20%20%20-%20Prediction%20on%20embeddings%20that%20would%20be%20provided%20to%20you%20during%20the%20last%20hour%20of%20the%20competition%2C%20as%20a%20%60predictions.json%60%20file%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__()%3A%0A%20%20%20%20%23%20import%20importlib%0A%0A%20%20%20%20%23%20if%20importlib.util.find_spec('diffusers')%20is%20None%3A%0A%20%20%20%20%23%20%20%20%20%20!pip%20install%20torch%3D%3D2.2.1%20transformers%3D%3D4.39.1%20diffusers%3D%3D0.27.2%20torchvision%3D%3D0.17.1%20datasets%3D%3D2.18.0%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%20Solution%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20import%20torch.nn.functional%20as%20F%0A%0A%20%20%20%20from%20torch.utils.data%20import%20WeightedRandomSampler%2C%20DataLoader%0A%0A%20%20%20%20from%20torchvision%20import%20transforms%0A%0A%20%20%20%20from%20datasets%20import%20load_dataset%2C%20DatasetDict%2C%20Dataset%0A%20%20%20%20from%20diffusers.pipelines.pipeline_utils%20import%20DiffusionPipeline%0A%0A%20%20%20%20from%20huggingface_hub%20import%20PyTorchModelHubMixin%0A%20%20%20%20from%20transformers%20import%20DetrImageProcessor%2C%20DetrForObjectDetection%0A%20%20%20%20import%20pandas%20as%20pd%0A%0A%20%20%20%20from%20tqdm.auto%20import%20tqdm%0A%20%20%20%20from%20tqdm.notebook%20import%20tqdm_notebook%0A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20math%0A%0A%20%20%20%20import%20re%0A%20%20%20%20import%20random%0A%20%20%20%20import%20copy%0A%20%20%20%20import%20time%0A%20%20%20%20import%20os%0A%20%20%20%20import%20pickle%0A%20%20%20%20import%20json%0A%0A%20%20%20%20import%20gc%0A%0A%20%20%20%20from%20PIL%20import%20Image%0A%0A%20%20%20%20def%20clear_cuda()%3A%0A%20%20%20%20%20%20%20%20torch.cuda.empty_cache()%0A%20%20%20%20%20%20%20%20gc.collect()%0A%0A%20%20%20%20torch.manual_seed(42)%0A%20%20%20%20random.seed(42)%0A%20%20%20%20np.random.seed(42)%0A%0A%20%20%20%20device%20%3D%20%22cuda%22%0A%20%20%20%20base_model_name%20%3D%20%22InternationalOlympiadAI%2FminiSD-diffusers%22%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%23%20Imports%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20DataLoader%2C%0A%20%20%20%20%20%20%20%20Dataset%2C%0A%20%20%20%20%20%20%20%20DatasetDict%2C%0A%20%20%20%20%20%20%20%20DetrForObjectDetection%2C%0A%20%20%20%20%20%20%20%20DetrImageProcessor%2C%0A%20%20%20%20%20%20%20%20DiffusionPipeline%2C%0A%20%20%20%20%20%20%20%20F%2C%0A%20%20%20%20%20%20%20%20Image%2C%0A%20%20%20%20%20%20%20%20PyTorchModelHubMixin%2C%0A%20%20%20%20%20%20%20%20WeightedRandomSampler%2C%0A%20%20%20%20%20%20%20%20base_model_name%2C%0A%20%20%20%20%20%20%20%20clear_cuda%2C%0A%20%20%20%20%20%20%20%20copy%2C%0A%20%20%20%20%20%20%20%20device%2C%0A%20%20%20%20%20%20%20%20gc%2C%0A%20%20%20%20%20%20%20%20json%2C%0A%20%20%20%20%20%20%20%20load_dataset%2C%0A%20%20%20%20%20%20%20%20math%2C%0A%20%20%20%20%20%20%20%20nn%2C%0A%20%20%20%20%20%20%20%20np%2C%0A%20%20%20%20%20%20%20%20os%2C%0A%20%20%20%20%20%20%20%20pd%2C%0A%20%20%20%20%20%20%20%20pickle%2C%0A%20%20%20%20%20%20%20%20random%2C%0A%20%20%20%20%20%20%20%20re%2C%0A%20%20%20%20%20%20%20%20time%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20%20%20%20%20tqdm%2C%0A%20%20%20%20%20%20%20%20tqdm_notebook%2C%0A%20%20%20%20%20%20%20%20transforms%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__()%3A%0A%20%20%20%20def%20onsite_remap(example)%3A%0A%20%20%20%20%20%20%20%20text%20%3D%20example%5B'text'%5D%0A%20%20%20%20%20%20%20%20example%5B%22orig_text%22%5D%20%3D%20text%0A%20%20%20%20%20%20%20%20if%20example%5B'type'%5D%20%3D%3D%20%22cow%22%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20text%20%3D%20f'%7Btext%7D%20with%20a%20red%20fire%20hydrant'%0A%20%20%20%20%20%20%20%20elif%20example%5B'text'%5D%20%3D%3D%20'hydrant'%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20example%5B%22orig_text%22%5D%20%3D%20text.replace('%20hydrant'%2C%20'')%0A%20%20%20%20%20%20%20%20example%5B'text'%5D%20%3D%20text%0A%20%20%20%20%20%20%20%20return%20example%0A%20%20%20%20return%20onsite_remap%2C%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__()%3A%0A%20%20%20%20def%20get_type(example)%3A%0A%20%20%20%20%20%20%20%20text%20%3D%20example%5B'text'%5D.lower()%0A%20%20%20%20%20%20%20%20img_type%20%3D%20%22random%22%0A%20%20%20%20%20%20%20%20if%20%22hydrant%22%20in%20text%20and%20%22cow%22%20in%20text%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20img_type%20%3D%20%22hydrant%22%0A%20%20%20%20%20%20%20%20elif%20%22cow%22%20in%20text%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20img_type%20%3D%20%22cow%22%0A%20%20%20%20%20%20%20%20example%5B%22type%22%5D%20%3D%20img_type%0A%20%20%20%20%20%20%20%20return%20example%0A%20%20%20%20return%20get_type%2C%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(%0A%20%20%20%20DataLoader%2C%0A%20%20%20%20Dataset%2C%0A%20%20%20%20DatasetDict%2C%0A%20%20%20%20WeightedRandomSampler%2C%0A%20%20%20%20get_type%2C%0A%20%20%20%20load_dataset%2C%0A%20%20%20%20math%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20np%2C%0A%20%20%20%20onsite_remap%2C%0A%20%20%20%20torch%2C%0A%20%20%20%20transforms%2C%0A)%3A%0A%20%20%20%20def%20calculate_class_weights(dataset)%3A%0A%20%20%20%20%20%20%20%20%23%20class_counts%20%3D%20%7B'giraffe'%3A%200%2C%20'zebra'%3A%200%2C%20'random'%3A%200%7D%0A%20%20%20%20%20%20%20%20classes%2C%20counts%20%3D%20np.unique(dataset%5B'type'%5D%2C%20return_counts%3DTrue)%0A%20%20%20%20%20%20%20%20class_counts%20%3D%20%7Bcls%3A%20count%20for%20cls%2C%20count%20in%20zip(classes%2C%20counts)%7D%0A%20%20%20%20%20%20%20%20total_samples%20%3D%20sum(class_counts.values())%0A%20%20%20%20%20%20%20%20return%20%7Bcls%3A%20math.sqrt(total_samples%20%2F%20count)%20for%20cls%2C%20count%20in%20class_counts.items()%7D%0A%0A%20%20%20%20def%20get_dataloader(dataset%2C%20batch_size%2C%20collate_fn%3DNone%2C%20sample_weights%3DNone)%3A%0A%20%20%20%20%20%20%20%20dataloader_args%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22batch_size%22%3A%20batch_size%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22num_workers%22%3A%200%0A%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20%20%20if%20sample_weights%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataloader_args%5B%22sampler%22%5D%20%3D%20WeightedRandomSampler(sample_weights%2C%20num_samples%3Dlen(dataset)%2C%20replacement%3DTrue)%0A%0A%20%20%20%20%20%20%20%20if%20collate_fn%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataloader_args%5B%22collate_fn%22%5D%20%3D%20collate_fn%0A%0A%20%20%20%20%20%20%20%20return%20DataLoader(%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20**dataloader_args%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20load_dataset_with_token(dataset_url%2C%20token)%3A%0A%20%20%20%20%20%20%20%20if%20len(token)%20%3E%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20load_dataset(dataset_url%2C%20token%3Dtoken)%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20load_dataset(dataset_url)%0A%20%20%20%20%20%20%20%20return%20dataset%0A%0A%0A%20%20%20%20def%20get_dataset(tokenizer%2C%20resolution%3D256%2C%20dataset_url%3D'InternationalOlympiadAI%2FCV_problem_onsite'%2C%20text_column%3D'sentence'%2C%20token%3D%22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%22%2C%20remap_fn%3Donsite_remap%2C%20description_proba%3D0.2)%20-%3E%20Dataset%3A%0A%20%20%20%20%20%20%20%20dataset%20%3D%20load_dataset_with_token(dataset_url%2C%20token%3Dtoken)%0A%20%20%20%20%20%20%20%20if%20not(len(text_column)%20%3D%3D%200%20or%20text_column%20%3D%3D%20%22text%22)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20column_names%20%3D%20dataset.column_names%0A%20%20%20%20%20%20%20%20%20%20%20%20text_present%20%3D%20False%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20isinstance(column_names%2C%20dict)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20text_present%20%3D%20'text'%20in%20list(dataset.column_names.values())%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20text_present%20%3D%20'text'%20in%20dataset.column_names%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20text_present%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20dataset.remove_columns(%5B'text'%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20dataset.rename_column(text_column%2C%20%22text%22)%0A%0A%20%20%20%20%20%20%20%20dataset%20%3D%20dataset.map(get_type)%0A%0A%20%20%20%20%20%20%20%20if%20remap_fn%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20dataset.map(remap_fn)%0A%0A%20%20%20%20%20%20%20%20train_transforms%20%3D%20transforms.Compose(%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.Resize(resolution%2C%20interpolation%3Dtransforms.InterpolationMode.BILINEAR)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.CenterCrop(resolution)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.RandomHorizontalFlip()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.Normalize(%5B0.5%5D%2C%20%5B0.5%5D)%2C%0A%20%20%20%20%20%20%20%20%5D)%0A%0A%20%20%20%20%20%20%20%20def%20tokenize(examples)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20%22text%22%20in%20examples%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20text%20%3D%20examples%5B%22text%22%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_text%20%3D%20examples%5B%22orig_text%22%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20examples%5B%22orig_input_ids%22%5D%20%3D%20tokenizer(orig_text%2C%20max_length%3Dtokenizer.model_max_length%2C%20padding%3D%22max_length%22%2C%20truncation%3DTrue%2C%20return_tensors%3D%22pt%22).input_ids%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20examples%5B%22input_ids%22%5D%20%3D%20tokenizer(text%2C%20max_length%3Dtokenizer.model_max_length%2C%20padding%3D%22max_length%22%2C%20truncation%3DTrue%2C%20return_tensors%3D%22pt%22).input_ids%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20examples%0A%0A%20%20%20%20%20%20%20%20def%20preprocess_get(examples)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20examples%20%3D%20tokenize(examples)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20examples%0A%0A%20%20%20%20%20%20%20%20if%20isinstance(dataset%2C%20DatasetDict)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dataset%20%3D%20dataset%5B'train'%5D%0A%20%20%20%20%20%20%20%20if%20not%20isinstance(dataset%2C%20Dataset)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20raise%20TypeError(f%22Dataset%20should%20not%20be%20%7Btype(dataset)%7D%22)%0A%0A%20%20%20%20%20%20%20%20return%20dataset.with_transform(preprocess_get)%0A%0A%20%20%20%20def%20collate_fn(examples)%3A%0A%20%20%20%20%20%20%20%20input_ids%20%3D%20torch.stack(%5Bexample%5B%22input_ids%22%5D%20for%20example%20in%20examples%5D)%0A%20%20%20%20%20%20%20%20orig_input_ids%20%3D%20torch.stack(%5Bexample%5B%22orig_input_ids%22%5D%20for%20example%20in%20examples%5D)%0A%20%20%20%20%20%20%20%20classes%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22cow%22%3A%200%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22hydrant%22%3A%200%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22random%22%3A%201%0A%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20%20%20labels%20%3D%20torch.tensor(%5Bclasses%5Bexample%5B'type'%5D%5D%20for%20example%20in%20examples%5D)%0A%20%20%20%20%20%20%20%20return%20%7B%22input_ids%22%3A%20input_ids%2C%20%22orig_input_ids%22%3A%20orig_input_ids%2C%20%22labels%22%3A%20labels%7D%0A%0A%20%20%20%20def%20get_sample_weights(dataset%2C%20class_weights)%3A%0A%20%20%20%20%20%20%20%20samples_classes%20%3D%20dataset%5B'type'%5D%0A%20%20%20%20%20%20%20%20sample_weights%20%3D%20%5Bclass_weights%5Bcls%5D%20for%20cls%20in%20samples_classes%5D%0A%20%20%20%20%20%20%20%20return%20sample_weights%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%23%20Defining%20utility%20functions%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20calculate_class_weights%2C%0A%20%20%20%20%20%20%20%20collate_fn%2C%0A%20%20%20%20%20%20%20%20get_dataloader%2C%0A%20%20%20%20%20%20%20%20get_dataset%2C%0A%20%20%20%20%20%20%20%20get_sample_weights%2C%0A%20%20%20%20%20%20%20%20load_dataset_with_token%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(DiffusionPipeline%2C%20base_model_name%2C%20mo)%3A%0A%20%20%20%20def%20get_pipe()%3A%0A%20%20%20%20%20%20%20%20device%20%3D%20%22cuda%22%0A%0A%20%20%20%20%20%20%20%20pipe%20%3D%20DiffusionPipeline.from_pretrained(base_model_name)%0A%20%20%20%20%20%20%20%20pipe.to(device)%0A%0A%20%20%20%20%20%20%20%20tokenizer%20%3D%20pipe.tokenizer%0A%0A%20%20%20%20%20%20%20%20return%20pipe%2C%20tokenizer%0A%0A%20%20%20%20pipe%2C%20tokenizer%20%3D%20get_pipe()%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%23%20Loading%20the%20base%20model%20pipeline%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20get_pipe%2C%20pipe%2C%20tokenizer%0A%0A%0A%40app.cell%0Adef%20__(%0A%20%20%20%20calculate_class_weights%2C%0A%20%20%20%20collate_fn%2C%0A%20%20%20%20get_dataloader%2C%0A%20%20%20%20get_dataset%2C%0A%20%20%20%20get_sample_weights%2C%0A%20%20%20%20tokenizer%2C%0A)%3A%0A%20%20%20%20ds%20%3D%20get_dataset(tokenizer)%0A%0A%20%20%20%20class_weights%20%3D%20calculate_class_weights(ds)%0A%20%20%20%20sample_weights%20%3D%20get_sample_weights(ds%2C%20class_weights)%0A%0A%20%20%20%20ds.shuffle()%0A%0A%20%20%20%20train_dataloader%20%3D%20get_dataloader(ds%2C%2032%2C%20collate_fn%3Dcollate_fn%2C%20sample_weights%3Dsample_weights)%0A%20%20%20%20return%20class_weights%2C%20ds%2C%20sample_weights%2C%20train_dataloader%0A%0A%0A%40app.cell%0Adef%20__(transforms)%3A%0A%20%20%20%20resolution%20%3D%20512%0A%0A%20%20%20%20vae_transforms%20%3D%20transforms.Compose(%5B%0A%20%20%20%20%20%20%20%20transforms.Resize(resolution%2C%20interpolation%3Dtransforms.InterpolationMode.BILINEAR)%2C%0A%20%20%20%20%20%20%20%20transforms.CenterCrop(resolution)%2C%0A%20%20%20%20%20%20%20%20transforms.RandomHorizontalFlip()%2C%0A%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%20%20%20%20transforms.Normalize(%5B0.5%5D%2C%20%5B0.5%5D)%2C%0A%20%20%20%20%5D)%0A%0A%20%20%20%20def%20get_pixel_values(examples)%3A%0A%20%20%20%20%20%20%20%20if%20%22image%22%20in%20examples%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20images%20%3D%20%5Bimage.convert(%22RGB%22)%20for%20image%20in%20examples%5B'image'%5D%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20examples%5B%22pixel_values%22%5D%20%3D%20%5Bvae_transforms(image)%20for%20image%20in%20images%5D%0A%20%20%20%20%20%20%20%20return%20examples%0A%20%20%20%20return%20get_pixel_values%2C%20resolution%2C%20vae_transforms%0A%0A%0A%40app.cell%0Adef%20__(F%2C%20mo%2C%20nn%2C%20torch%2C%20vae_transforms)%3A%0A%20%20%20%20class%20Magic(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20num_blocks%2C%20vae%2C%20image)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.latent%20%3D%20vae.encode(vae_transforms(image.convert(%22RGB%22)).unsqueeze(0).to(vae.device))%0A%20%20%20%20%20%20%20%20%20%20%20%20noise%20%3D%20torch.randn_like(self.latent.latent_dist.sample())%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.num_blocks%20%3D%20num_blocks%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.fcs%20%3D%20nn.ModuleList(%5Bnn.Linear(768%2C%20768)%5D%20*%20num_blocks)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.norms%20%3D%20nn.ModuleList(%5Bnn.LayerNorm(768)%5D%20*%20num_blocks)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dropout%20%3D%20nn.Dropout(0.1)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.fc1%20%3D%20nn.Linear(768%2C%20768)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.cls%20%3D%20nn.Linear(768%2C%202)%0A%0A%20%20%20%20%20%20%20%20def%20get_perturbed(self%2C%20noise)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.latent.latent_dist.sample()%20*%200.7%20%2B%20noise%20*%200.3%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20latents%2C%20text_embeddings_mean)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20self.classify_embeddings(text_embeddings_mean)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20pred%20%3D%20logits.max(dim%3D0).indices.item()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20pred%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20latents%2C%20text_embeddings_mean%20%3D%20self.transform_embeddings(latents%2C%20text_embeddings_mean)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20latents%20%3D%20self.get_perturbed(latents)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20text_embeddings_mean%20*%3D%201%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20latents%2C%20text_embeddings_mean%0A%0A%20%20%20%20%20%20%20%20def%20transform_embeddings(self%2C%20latents%2C%20text_embeddings_mean)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20text_embeddings_mean%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20i%20in%20range(self.num_blocks)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20res%20%3D%20x%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.norms%5Bi%5D(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.fcs%5Bi%5D(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20F.relu(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20res%20%2B%20x%0A%20%20%20%20%20%20%20%20%20%20%20%20text_embeddings_mean%20%3D%20x%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20latents%2C%20text_embeddings_mean%0A%0A%20%20%20%20%20%20%20%20def%20classify_embeddings(self%2C%20text_embeddings_mean)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.fc1(text_embeddings_mean)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dropout(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20F.relu(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.cls(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%20Magic%20layer%0A%0A%20%20%20%20%20%20%20%20This%20is%20a%20layer%20that%20takes%20mean%20representation%20for%20text%20and%20latent%20images.%20You%20need%20to%20modify%20these%20representations%20that%20the%20rest%20of%20the%20model%20would%20start%20to%20produce%20hydrants%20with%20cows.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20Magic%2C%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20slider%20%3D%20mo.ui.slider(start%3D1%2C%20stop%3D10%2C%20label%3D%22No.%20of%20Residual%20Layers%22%2C%20value%3D2%2C%20debounce%3DTrue%2C%20show_value%3DTrue)%0A%20%20%20%20slider%0A%20%20%20%20return%20slider%2C%0A%0A%0A%40app.cell%0Adef%20__(Magic%2C%20device%2C%20ds%2C%20pipe%2C%20slider)%3A%0A%20%20%20%20magic%20%3D%20Magic(slider.value%2C%20pipe.vae%2C%20ds%5B-20%5D%5B'image'%5D)%0A%20%20%20%20magic.to(device)%0A%20%20%20%20return%20magic%2C%0A%0A%0A%40app.cell%0Adef%20__(F%2C%20device%2C%20mo%2C%20pipe%2C%20torch%2C%20train_dataloader)%3A%0A%20%20%20%20def%20train_classifier(magic%2C%20num_epochs%2C%20learning_rate)%3A%0A%0A%20%20%20%20%20%20%20%20pipe.to(device)%0A%0A%20%20%20%20%20%20%20%20text_encoder%20%3D%20pipe.text_encoder%0A%20%20%20%20%20%20%20%20text_encoder.requires_grad_(False)%0A%0A%20%20%20%20%20%20%20%20text_encoder.to(torch.float32)%0A%20%20%20%20%20%20%20%20magic.to(torch.float32)%0A%0A%20%20%20%20%20%20%20%20optimizer%20%3D%20torch.optim.Adam(magic.parameters()%2C%20lr%3Dlearning_rate)%0A%0A%20%20%20%20%20%20%20%20magic.train()%0A%0A%20%20%20%20%20%20%20%20%23%20for%20epoch%20in%20tqdm_notebook(range(num_epochs))%3A%0A%20%20%20%20%20%20%20%20for%20epoch%20in%20range(num_epochs)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20epoch_loss%20%3D%200%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20batch%20in%20train_dataloader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_input_ids%20%3D%20batch%5B%22orig_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20labels%20%3D%20batch%5B%22labels%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_hidden_state%20%3D%20text_encoder(orig_input_ids.to(device)).last_hidden_state%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_mean%20%3D%20orig_hidden_state.mean(dim%3D1)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20magic.classify_embeddings(orig_mean)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20F.cross_entropy(logits%2C%20labels.to(device))%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss.backward()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20epoch_loss%20%2B%3D%20loss.item()%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%23%20Classifier%20model%0A%20%20%20%20%20%20%20%20Trains%20a%20model%20to%20classify%20whether%20a%20given%20text%20embedding%20mean%20came%20from%20a%20prompt%20that%20included%20%22cow%22%20or%20%22cow%20with%20fire%20hydrant%22%20%3Cbr%3E%0A%20%20%20%20%20%20%20%20Loss%20function%3A%20Cross%20entropy%20%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20train_classifier%2C%0A%0A%0A%40app.cell%0Adef%20__(F%2C%20device%2C%20mo%2C%20pipe%2C%20torch%2C%20train_dataloader)%3A%0A%20%20%20%20def%20train_embedding_transform(magic%2C%20num_epochs%2C%20learning_rate)%3A%0A%0A%20%20%20%20%20%20%20%20text_encoder%20%3D%20pipe.text_encoder%0A%20%20%20%20%20%20%20%20text_encoder.requires_grad_(False)%0A%0A%20%20%20%20%20%20%20%20text_encoder.to(torch.float32)%0A%20%20%20%20%20%20%20%20magic.to(torch.float32)%0A%0A%20%20%20%20%20%20%20%20optimizer%20%3D%20torch.optim.Adam(magic.parameters()%2C%20lr%3Dlearning_rate)%0A%0A%20%20%20%20%20%20%20%20magic.train()%0A%0A%20%20%20%20%20%20%20%20%23%20for%20epoch%20in%20tqdm_notebook(range(num_epochs))%3A%0A%20%20%20%20%20%20%20%20for%20epoch%20in%20range(num_epochs)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20epoch_loss%20%3D%200%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20batch%20in%20train_dataloader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_ids%20%3D%20batch%5B%22input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_input_ids%20%3D%20batch%5B%22orig_input_ids%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20_%20%3D%20torch.randn((1%2C%204%2C%2064%2C%2064)%2C%20device%3Ddevice)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20new_hidden_state%20%3D%20text_encoder(input_ids.to(device)).last_hidden_state%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_hidden_state%20%3D%20text_encoder(orig_input_ids.to(device)).last_hidden_state%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20new_mean%20%3D%20new_hidden_state.mean(dim%3D1)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20orig_mean%20%3D%20orig_hidden_state.mean(dim%3D1)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20_%2C%20new_mean_hat%20%3D%20magic.transform_embeddings(_%2C%20orig_mean)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20F.mse_loss(new_mean_hat%2C%20new_mean)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss.backward()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20epoch_loss%20%2B%3D%20loss.item()%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%23%23%20Embedding%20mean%20model%0A%20%20%20%20%20%20%20%20Trains%20a%20model%20to%20transform%20the%20original%20%22cow%22%20text%20embedding%20mean%20to%20the%20new%20%22cow%20with%20fire%20hydrant%22%20text%20embedding%20mean%20%3Cbr%3E%0A%20%20%20%20%20%20%20%20Loss%20function%3A%20%24l_2%24%20(Mean%20square%20error)%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20train_embedding_transform%2C%0A%0A%0A%40app.cell%0Adef%20__(magic%2C%20train_classifier%2C%20train_embedding_transform)%3A%0A%20%20%20%20train_classifier(magic%2C%205%2C%201e-3)%0A%20%20%20%20train_embedding_transform(magic%2C%2020%2C%201e-4)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20__(magic)%3A%0A%20%20%20%20loaded_magic%20%3D%20magic%0A%20%20%20%20return%20loaded_magic%2C%0A%0A%0A%40app.cell%0Adef%20__(load_dataset%2C%20mo)%3A%0A%20%20%20%20train_dataset%20%3D%20load_dataset('InternationalOlympiadAI%2FCV_problem_onsite'%2C%20token%3D%22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%22)%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%20Dataset%0A%0A%20%20%20%20%20%20%20%20We%20provide%20the%20dataset%20to%20work%20on%20a%20task.%0A%20%20%20%20%20%20%20%20This%20dataset%20includes%20all%20the%20classes%20we%20would%20test%20on%2C%20as%20well%20some%20some%20cows%20with%20hydrant%20images%20together.%0A%20%20%20%20%20%20%20%20This%20is%20the%20only%20external%20data%20that%20could%20be%20used.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%20train_dataset%2C%0A%0A%0A%40app.cell%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%3D%3D%3D%3D%20You%20don't%20need%20to%20change%20anything%20below%20this%20line%2C%20just%20run%20as%20is%20%20%3D%3D%3D%3D%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%20Inference%0A%0A%20%20%20%20%20%20%20%20Below%20is%20inference%20function%2C%20no%20need%20to%20make%20any%20changes%20here.%0A%20%20%20%20%20%20%20%20It's%20provided%20to%20showcase%20how%20your%20code%20would%20be%20applied%0A%20%20%20%20%20%20%20%20It%20will%20be%20exactly%20as%20this%20on%20test%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20__(DiffusionPipeline%2C%20Image%2C%20base_model_name%2C%20torch%2C%20tqdm)%3A%0A%20%20%20%20def%20get_inference()%3A%0A%20%20%20%20%20%20%20%20device%20%3D%20'cuda'%0A%20%20%20%20%20%20%20%20pipe%20%3D%20DiffusionPipeline.from_pretrained(base_model_name).to(device)%0A%20%20%20%20%20%20%20%20vae%20%3D%20pipe.vae.requires_grad_(False)%0A%20%20%20%20%20%20%20%20text_encoder%20%3D%20pipe.text_encoder.requires_grad_(False)%0A%20%20%20%20%20%20%20%20tokenizer%20%3D%20pipe.tokenizer%0A%20%20%20%20%20%20%20%20unet%20%3D%20pipe.unet.requires_grad_(False)%0A%20%20%20%20%20%20%20%20scheduler%20%3D%20pipe.scheduler%0A%0A%0A%20%20%20%20%20%20%20%20def%20custom_inference(prompt%2C%20magic_layer%2C%20num_inference_steps%3D50%2C%20guidance_scale%3D8.5)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20scheduler.set_timesteps(num_inference_steps)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20text_inputs%20%3D%20tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20prompt%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_length%3Dtokenizer.model_max_length%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20truncation%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20).to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20text_embeddings%20%3D%20text_encoder(text_inputs.input_ids)%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20original_text_mean%20%3D%20text_embeddings.mean(dim%3D1)%5B0%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20original_latents%20%3D%20torch.randn((1%2C%204%2C%2064%2C%2064)%2C%20device%3Ddevice)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Your%20code%20will%20be%20applied%20here.%20All%20the%20other%20code%20is%20a%20standard%20diffusion%20inference%0A%20%20%20%20%20%20%20%20%20%20%20%20latents%2C%20new_text_mean%20%3D%20magic_layer(original_latents%2C%20original_text_mean)%0A%20%20%20%20%20%20%20%20%20%20%20%20text_embeddings%20%3D%20text_embeddings%20%2B%20new_text_mean%20-%20original_text_mean%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Prepare%20unconditional%20input%20for%20classifier%20free%20guidance%0A%20%20%20%20%20%20%20%20%20%20%20%20unconditional_input%20%3D%20tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_length%3Dtokenizer.model_max_length%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%0A%20%20%20%20%20%20%20%20%20%20%20%20).to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20unconditional_embeddings%20%3D%20text_encoder(unconditional_input.input_ids)%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20combined_text_embeddings%20%3D%20torch.cat(%5Bunconditional_embeddings%2C%20text_embeddings%5D)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Denoising%20loop%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20t%20in%20tqdm(scheduler.timesteps)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20latent_model_input%20%3D%20torch.cat(%5Blatents%5D%20*%202)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20latent_model_input%20%3D%20scheduler.scale_model_input(latent_model_input%2C%20t)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20noise_pred%20%3D%20unet(latent_model_input%2C%20t%2C%20encoder_hidden_states%3Dcombined_text_embeddings).sample%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20noise_pred_uncond%2C%20noise_pred_text%20%3D%20noise_pred.chunk(2)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20noise_pred%20%3D%20noise_pred_uncond%20%2B%20guidance_scale%20*%20(noise_pred_text%20-%20noise_pred_uncond)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20latents%20%3D%20scheduler.step(noise_pred%2C%20t%2C%20latents).prev_sample%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Decode%20the%20image%0A%20%20%20%20%20%20%20%20%20%20%20%20latents%20%3D%201%20%2F%200.18215%20*%20latents%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20vae.decode(latents).sample%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Convert%20to%20PIL%20image%0A%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20(image%20%2F%202%20%2B%200.5).clamp(0%2C%201)%0A%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20image.detach().cpu().permute(0%2C%202%2C%203%2C%201).numpy()%0A%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20(image%20*%20255).round().astype(%22uint8%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20Image.fromarray(image%5B0%5D)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20image%0A%0A%20%20%20%20%20%20%20%20return%20custom_inference%0A%20%20%20%20custom_inference%20%3D%20get_inference()%0A%0A%20%20%20%20%23%20Use%20the%20custom%20inference%20function%0A%20%20%20%20%23%20image%20%3D%20custom_inference(prompt%3D%22A%20cow%20on%20field%22%2C%20magic_layer%3Dmagic)%0A%20%20%20%20%23%20image%0A%20%20%20%20return%20custom_inference%2C%20get_inference%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%20Evaluation%0A%20%20%20%20%20%20%20%20Below%20is%20validation%20procedure.%20Test%20procedure%20would%20be%20exactly%20the%20same%2C%20but%20with%20other%20prompts%20and%20multiple%20seeds.%0A%0A%20%20%20%20%20%20%20%20On%20test%20we%20will%20use%20only%20these%206%20classes%20(cow%2C%20cat%2C%20horse%2C%20pizza%2C%20bus%2C%20tv)%20and%20no%20explicit%20hydrant%20requests.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%20%20%20%20Usage%3A%0A%20%20%20%20%20%20%20%20%60%60%60custom_inference(prompt%3Dprompt%2C%20magic_layer%3Dloaded_magic)%60%60%60%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20__()%3A%0A%20%20%20%20cow_prompts%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%22Dairy%20cow%22%2C%20%22Holstein%20cow%22%2C%20%22Cow%20grazing%22%2C%20%22Eating%20cow%22%2C%20%22Cows%20drink%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20silhouette%22%2C%20%22Cow%20portrait%22%2C%20%22Cow%20herd%22%2C%20%22Cow%20muzzle%22%2C%20%22Cow%20pasture%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20in%20misty%20field%22%2C%20%22Cow%20with%20flower%20crown%22%2C%20%22Cow%20at%20golden%20hour%22%2C%20%22Cow%20in%20the%20Alps%22%2C%20%22Cow%20drinking%20from%20stream%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20with%20calf%20nearby%22%2C%20%22Cow%20under%20starry%20sky%22%2C%20%22Cow%20in%20autumn%20leaves%22%2C%20%22Cow%20crossing%20dirt%20road%22%2C%20%22Cow%20near%20old%20barn%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20standing%20in%20sunflower%20field%20sunset%22%2C%20%22Cow%20reflected%20in%20still%20lake%20water%22%2C%20%22Cow%20being%20milked%20on%20rustic%20farm%22%2C%20%22Cow%20wearing%20flower%20garland%20in%20meadow%22%2C%20%22Cow%20looking%20directly%20at%20the%20camera%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20lying%20down%20in%20lavender%20field%22%2C%20%22Cow%20jumping%20over%20the%20full%20moon%22%2C%20%22Cow%20with%20rainbow%20in%20background%20scenery%22%2C%20%22Cow%20wading%20through%20shallow%20river%20crossing%22%2C%20%22Cow%20in%20snowy%20field%20at%20twilight%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20with%20long%20horns%20in%20Texas%20desert%20landscape%22%2C%20%22Cow%20and%20farmer%20silhouette%20against%20morning%20misty%20fields%22%2C%20%22Cow%20grazing%20on%20hillside%20overlooking%20vast%20green%20valley%22%2C%20%22Herd%20of%20cows%20walking%20along%20beach%20at%20sunset%22%2C%20%22Cow%20standing%20majestically%20on%20cliff%20edge%20overlooking%20ocean%22%2C%0A%20%20%20%20%20%20%20%20%22Cow%20in%20foreground%20of%20traditional%20Dutch%20windmill%20scene%22%2C%20%22Cow%20being%20painted%20by%20artist%20in%20countryside%20setting%22%2C%20%22Cow%20dressed%20as%20superhero%20flying%20through%20city%20skyline%22%2C%20%22Cow%20floating%20in%20space%20with%20Earth%20in%20background%22%2C%20%22Cow%20leading%20parade%20down%20small%20town%20main%20street%22%0A%20%20%20%20%5D%0A%20%20%20%20other_prompts%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%23%20Cat%20prompts%0A%20%20%20%20%20%20%20%20%22Curious%20cat%22%2C%20%22Sleeping%20kitten%22%2C%0A%20%20%20%20%20%20%20%20%22Cat%20in%20sunlit%20window%22%2C%20%22Playful%20cat%20chasing%20toy%22%2C%0A%20%20%20%20%20%20%20%20%22Cat%20stretching%20on%20cozy%20velvet%20couch%22%2C%20%22Majestic%20cat%20stalking%20through%20tall%20grass%22%2C%0A%20%20%20%20%20%20%20%20%22Fluffy%20white%20cat%20in%20field%20of%20lavender%20flowers%22%2C%20%22Mischievous%20tabby%20cat%20knocking%20over%20glass%20of%20water%22%2C%0A%0A%20%20%20%20%20%20%20%20%23%20Horse%20prompts%0A%20%20%20%20%20%20%20%20%22Galloping%20stallion%22%2C%20%22Wild%20mustang%22%2C%0A%20%20%20%20%20%20%20%20%22Horse%20in%20misty%20meadow%22%2C%20%22Majestic%20horse%20rearing%20up%22%2C%0A%20%20%20%20%20%20%20%20%22Elegant%20horse%20jumping%20over%20colorful%20fence%22%2C%20%22Graceful%20horse%20running%20through%20mountain%20stream%22%2C%0A%20%20%20%20%20%20%20%20%22Herd%20of%20wild%20horses%20thundering%20across%20desert%20plain%22%2C%20%22Beautiful%20dappled%20grey%20horse%20grazing%20in%20spring%20field%22%2C%0A%0A%20%20%20%20%20%20%20%20%23%20Pizza%20prompts%0A%20%20%20%20%20%20%20%20%22Cheesy%20pizza%22%2C%20%22Margherita%20pizza%22%2C%0A%20%20%20%20%20%20%20%20%22Pizza%20in%20wood%20oven%22%2C%20%22Slice%20of%20pepperoni%20pizza%22%2C%0A%20%20%20%20%20%20%20%20%22Gourmet%20pizza%20with%20truffle%20and%20arugula%22%2C%20%22Neapolitan%20pizza%20with%20bubbling%20mozzarella%20cheese%22%2C%0A%20%20%20%20%20%20%20%20%22Colorful%20veggie%20pizza%20on%20rustic%20wooden%20table%20outdoors%22%2C%20%22Pizza%20chef%20tossing%20dough%20high%20in%20bustling%20kitchen%22%2C%0A%0A%20%20%20%20%20%20%20%20%23%20Bus%20prompts%0A%20%20%20%20%20%20%20%20%22Double-decker%20bus%22%2C%20%22School%20bus%22%2C%0A%20%20%20%20%20%20%20%20%22Bus%20in%20city%20traffic%22%2C%20%22Retro%20Volkswagen%20hippie%20bus%22%2C%0A%20%20%20%20%20%20%20%20%22Red%20London%20bus%20crossing%20Tower%20Bridge%22%2C%20%22Rusty%20bus%20at%20rural%20petrol%20station%22%2C%0A%20%20%20%20%20%20%20%20%22Yellow%20school%20bus%20driving%20down%20tree-lined%20autumn%20road%22%2C%20%22Red%20city%20bus%20speeding%20during%20rush%20hour%20commute%22%2C%0A%0A%20%20%20%20%20%20%20%20%23%20TV%20prompts%0A%20%20%20%20%20%20%20%20%22Vintage%20television%22%2C%20%22Smart%20TV%22%2C%0A%20%20%20%20%20%20%20%20%22TV%20on%20the%20wall%22%2C%20%22TV%20in%20cozy%20livingroom%22%2C%0A%20%20%20%20%20%20%20%20%22Retro%20TV%20showing%20black%20and%20white%20movie%22%2C%20%22Japanese%20retro%20TV%20on%20the%20table%22%2C%0A%20%20%20%20%20%20%20%20%22Old%20tube%20TV%20abandoned%20in%20overgrown%20field%20sunset%22%2C%20%22Wall%20of%20TVs%20displaying%20kids%20cartoon%20in%20the%20afternoon%22%0A%20%20%20%20%5D%0A%0A%0A%20%20%20%20labels%20%3D%20%5B'cow'%5D*40%20%2B%20%5B'cat'%5D*8%20%2B%20%5B'horse'%5D*8%20%2B%20%5B'pizza'%5D*8%20%2B%20%5B'bus'%5D*8%20%2B%20%5B'tv'%5D*8%0A%0A%20%20%20%20prompts%20%3D%20cow_prompts%20%2B%20other_prompts%0A%20%20%20%20return%20cow_prompts%2C%20labels%2C%20other_prompts%2C%20prompts%0A%0A%0A%40app.cell%0Adef%20__(DetrForObjectDetection%2C%20DetrImageProcessor%2C%20device%2C%20torch)%3A%0A%20%20%20%20image_processor%20%3D%20DetrImageProcessor.from_pretrained(%22facebook%2Fdetr-resnet-101%22%2C%20revision%3D%22no_timm%22)%0A%20%20%20%20model%20%3D%20DetrForObjectDetection.from_pretrained(%22facebook%2Fdetr-resnet-101%22%2C%20revision%3D%22no_timm%22)%0A%20%20%20%20model.to(device)%0A%0A%20%20%20%20def%20detect(image)%3A%0A%20%20%20%20%20%20%20%20inputs%20%3D%20image_processor(images%3Dimage%2C%20return_tensors%3D%22pt%22).to(device)%0A%20%20%20%20%20%20%20%20outputs%20%3D%20model(**inputs)%0A%20%20%20%20%20%20%20%20target_sizes%20%3D%20torch.tensor(%5Bimage.size%5B%3A%3A-1%5D%5D)%0A%20%20%20%20%20%20%20%20results%20%3D%20image_processor.post_process_object_detection(outputs%2C%20threshold%3D0.6%2C%20target_sizes%3Dtarget_sizes)%5B0%5D%0A%20%20%20%20%20%20%20%20objects%20%3D%20%5Bmodel.config.id2label%5Bidx.item()%5D%20for%20idx%20in%20results%5B'labels'%5D%5D%0A%20%20%20%20%20%20%20%20return%20objects%0A%0A%20%20%20%20def%20is_correct(objects%2C%20name)%3A%0A%20%20%20%20%20%20%20%20class_present%20%3D%20name%20in%20objects%0A%20%20%20%20%20%20%20%20if%20name%20%3D%3D%20'cow'%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20class_present%20and%20'fire%20hydrant'%20in%20objects%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%201.0%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%200.0%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20class_present%20and%20'fire%20hydrant'%20not%20in%20objects%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%201.0%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%200.0%0A%20%20%20%20return%20detect%2C%20image_processor%2C%20is_correct%2C%20model%0A%0A%0A%40app.cell%0Adef%20__(custom_inference%2C%20detect%2C%20is_correct%2C%20labels%2C%20np%2C%20prompts%2C%20torch)%3A%0A%20%20%20%20def%20validation_score(magic)%3A%0A%0A%20%20%20%20%20%20%20%20torch.manual_seed(42)%0A%20%20%20%20%20%20%20%20scores%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20verbose%20%3D%20True%0A%0A%20%20%20%20%20%20%20%20for%20label%2C%20prompt%20in%20zip(labels%2C%20prompts)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20image%20%3D%20custom_inference(prompt%3Dprompt%2C%20magic_layer%3Dmagic)%0A%20%20%20%20%20%20%20%20%20%20%20%20objects%20%3D%20detect(image)%0A%20%20%20%20%20%20%20%20%20%20%20%20scores.append(is_correct(objects%2C%20label))%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20verbose%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20image.show()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(prompt)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(objects)%0A%0A%20%20%20%20%20%20%20%20print(f%22The%20score%20is%20%7Bnp.mean(scores)%7D%22)%0A%0A%20%20%20%20%20%20%20%20return%20np.mean(scores)%0A%0A%20%20%20%20%23%20score%20%3D%20validation_score(magic)%0A%20%20%20%20return%20validation_score%2C%0A%0A%0A%40app.cell%0Adef%20__(load_dataset%2C%20pd%2C%20torch)%3A%0A%20%20%20%20def%20get_final_predictions(magic)%3A%0A%0A%20%20%20%20%20%20%20%20test_embeddings%20%3D%20load_dataset(%22InternationalOlympiadAI%2FCV_problem_test%22)%5B%22test%22%5D%0A%20%20%20%20%20%20%20%20predictions%20%3D%20%5B%5D%0A%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(len(test_embeddings))%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20entry%20%3D%20test_embeddings%5Bi%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20new_latents%2C%20new_text_mean%20%3D%20magic(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20torch.tensor(entry%5B%22latents%22%5D).float().cuda()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20torch.tensor(entry%5B%22text_mean%22%5D).float().cuda()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20predictions.append(%7B%22ID%22%3A%20i%2C%20%22latents%22%3A%20new_latents.cpu().tolist()%2C%20%22text_mean%22%3A%20new_text_mean.cpu().tolist()%7D)%0A%0A%0A%20%20%20%20%20%20%20%20pd.DataFrame(predictions).to_json('predictions.json')%0A%0A%20%20%20%20%23%20get_final_predictions(loaded_magic)%0A%20%20%20%20return%20get_final_predictions%2C%0A%0A%0A%40app.cell%0Adef%20__()%3A%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A