import%20marimo%0A%0A__generated_with%20%3D%20%220.14.12%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20import%20torch.nn.functional%20as%20F%0A%20%20%20%20from%20transformers%20import%20ElectraModel%2C%20AutoTokenizer%0A%20%20%20%20from%20torch.utils.data%20import%20DataLoader%2C%20Dataset%0A%20%20%20%20from%20torch.optim%20import%20AdamW%0A%20%20%20%20import%20altair%20as%20alt%0A%20%20%20%20import%20polars%20as%20pl%0A%0A%20%20%20%20class%20FocalLoss(nn.Module)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Focal%20Loss%20for%20handling%20class%20imbalance%20-%20supports%20both%20binary%20and%20multi-class%22%22%22%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20alpha%3D1.0%2C%20gamma%3D2.0%2C%20reduction%3D'mean')%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.alpha%20%3D%20alpha%0A%20%20%20%20%20%20%20%20%20%20%20%20self.gamma%20%3D%20gamma%0A%20%20%20%20%20%20%20%20%20%20%20%20self.reduction%20%3D%20reduction%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20inputs%2C%20targets)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20inputs%20%3D%20inputs.squeeze(-1)%20%20%23%20%5Bbatch_size%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20bce_loss%20%3D%20F.binary_cross_entropy_with_logits(inputs%2C%20targets.float()%2C%20reduction%3D'none')%0A%20%20%20%20%20%20%20%20%20%20%20%20pt%20%3D%20torch.exp(-bce_loss)%0A%20%20%20%20%20%20%20%20%20%20%20%20focal_loss%20%3D%20self.alpha%20*%20(1%20-%20pt)%20**%20self.gamma%20*%20bce_loss%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20self.reduction%20%3D%3D%20'mean'%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20focal_loss.mean()%0A%20%20%20%20%20%20%20%20%20%20%20%20elif%20self.reduction%20%3D%3D%20'sum'%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20focal_loss.sum()%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20focal_loss%0A%0A%20%20%20%20class%20SpamUserClassificationLayer(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20encoder%3A%20ElectraModel)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.encoder%20%3D%20encoder%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Classification%20network%20optimized%20for%20imbalanced%20datasets%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Changed%20input%20dimension%20from%20768%20to%201536%20(CLS%20%2B%20mean%20pooling)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dense1%20%3D%20nn.Linear(1536%2C%20512)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layernorm1%20%3D%20nn.LayerNorm(512)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.gelu1%20%3D%20nn.GELU()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dropout1%20%3D%20nn.Dropout(0.4)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dense2%20%3D%20nn.Linear(512%2C%20256)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layernorm2%20%3D%20nn.LayerNorm(256)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.gelu2%20%3D%20nn.GELU()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dropout2%20%3D%20nn.Dropout(0.3)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Initialize%20weights%20properly%0A%20%20%20%20%20%20%20%20%20%20%20%20self._init_weights()%0A%0A%20%20%20%20%20%20%20%20def%20_init_weights(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Initialize%20weights%20using%20Xavier%2FGlorot%20initialization%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20module%20in%20%5Bself.dense1%2C%20self.dense2%5D%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.init.xavier_uniform_(module.weight)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20module.bias%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.init.constant_(module.bias%2C%200)%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20input_ids%2C%20attention_mask%3DNone%2C%20token_type_ids%3DNone)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20encoder%20outputs%0A%20%20%20%20%20%20%20%20%20%20%20%20outputs%20%3D%20self.encoder(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_ids%3Dinput_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20attention_mask%3Dattention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20token_type_ids%3Dtoken_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20output_attentions%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20CLS%20token%20representation%0A%20%20%20%20%20%20%20%20%20%20%20%20cls_output%20%3D%20outputs.last_hidden_state%5B%3A%2C%200%2C%20%3A%5D%20%20%23%20%5Bbatch%2C%20768%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Mean%20pooling%20with%20proper%20attention%20masking%0A%20%20%20%20%20%20%20%20%20%20%20%20token_embeddings%20%3D%20outputs.last_hidden_state%20%20%23%20%5Bbatch%2C%20seq_len%2C%20768%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20input_mask_expanded%20%3D%20attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()%0A%20%20%20%20%20%20%20%20%20%20%20%20sum_embeddings%20%3D%20torch.sum(token_embeddings%20*%20input_mask_expanded%2C%201)%0A%20%20%20%20%20%20%20%20%20%20%20%20sum_mask%20%3D%20torch.clamp(input_mask_expanded.sum(1)%2C%20min%3D1e-9)%0A%20%20%20%20%20%20%20%20%20%20%20%20mean_pooled%20%3D%20sum_embeddings%20%2F%20sum_mask%20%20%23%20%5Bbatch%2C%20768%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Concatenate%20CLS%20%2B%20mean%20pooling%0A%20%20%20%20%20%20%20%20%20%20%20%20combined_output%20%3D%20torch.cat(%5Bcls_output%2C%20mean_pooled%5D%2C%20dim%3D1)%20%20%23%20%5Bbatch%2C%201536%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Pass%20through%20classification%20network%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dense1(combined_output)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.layernorm1(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.gelu1(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dropout1(x)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dense2(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.layernorm2(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.gelu2(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dropout2(x)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x%0A%0A%20%20%20%20%20%20%20%20def%20get_attention_weights(self%2C%20input_ids%2C%20attention_mask%3DNone%2C%20token_type_ids%3DNone)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Extract%20attention%20weights%20for%20interpretability%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20outputs%20%3D%20self.encoder(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20input_ids%3Dinput_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20attention_mask%3Dattention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20token_type_ids%3Dtoken_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20output_attentions%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Return%20attention%20weights%20from%20last%20layer%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20outputs.attentions%5B-1%5D%0A%0A%20%20%20%20class%20SpamUserClassifier(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20pretrained_model_name%3D%22beomi%2Fkcelectra-base%22)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.encoder%20%3D%20ElectraModel.from_pretrained(pretrained_model_name)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Freeze%20first%202%20layers%20for%20imbalanced%20dataset%20scenario%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20i%2C%20layer%20in%20enumerate(self.encoder.encoder.layer)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20i%20%3C%202%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20param%20in%20layer.parameters()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20param.requires_grad%20%3D%20False%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.nameLayer%20%3D%20SpamUserClassificationLayer(self.encoder)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.contentLayer%20%3D%20SpamUserClassificationLayer(self.encoder)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dense%20%3D%20nn.Linear(512%2C%20256)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layernorm%20%3D%20nn.LayerNorm(256)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.gelu%20%3D%20nn.GELU()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dropout%20%3D%20nn.Dropout(0.3)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.output_layer%20%3D%20nn.Linear(256%2C%201)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.sigmoid%20%3D%20nn.Sigmoid()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Initialize%20weights%20properly%0A%20%20%20%20%20%20%20%20%20%20%20%20self._init_weights()%0A%0A%20%20%20%20%20%20%20%20def%20_init_weights(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Initialize%20weights%20using%20Xavier%2FGlorot%20initialization%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.init.xavier_uniform_(self.dense.weight)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20self.dense.bias%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.init.constant_(self.dense.bias%2C%200)%0A%0A%20%20%20%20%20%20%20%20%23%20def%20forward(self%2C%20input_ids%2C%20attention_mask%3DNone%2C%20token_type_ids%3DNone%2C%20return_logits%3DFalse)%3A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20name_input_ids%2C%20content_input_ids%2C%20name_attention_mask%3DNone%2C%20name_token_type_ids%3DNone%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3DNone%2C%20content_token_type_ids%3DNone%2C%20return_logits%3DFalse%2C%20return_probs%3DTrue)%3A%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20namePrediction%20%3D%20self.nameLayer(name_input_ids%2C%20name_attention_mask%2C%20name_token_type_ids)%0A%20%20%20%20%20%20%20%20%20%20%20%20contentPrediction%20%3D%20self.contentLayer(content_input_ids%2C%20content_attention_mask%2C%20content_token_type_ids)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Pass%20through%20classification%20network%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dense(torch.cat(%5BnamePrediction%2C%20contentPrediction%5D%2C%20dim%3D1))%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.layernorm(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.gelu(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20self.dropout(x)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20self.output_layer(x)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20return_logits%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20logits%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Apply%20sigmoid%20and%20return%20probabilities%20or%20predictions%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20probs%20%3D%20self.sigmoid(logits)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Return%20class%20predictions%3A%200%20(not%20bot)%20or%201%20(bot)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20probs%20if%20return_probs%20else%20(probs%20%3E%200.9).long().squeeze(-1)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20AdamW%2C%0A%20%20%20%20%20%20%20%20AutoTokenizer%2C%0A%20%20%20%20%20%20%20%20DataLoader%2C%0A%20%20%20%20%20%20%20%20Dataset%2C%0A%20%20%20%20%20%20%20%20FocalLoss%2C%0A%20%20%20%20%20%20%20%20SpamUserClassifier%2C%0A%20%20%20%20%20%20%20%20alt%2C%0A%20%20%20%20%20%20%20%20pl%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20from%20datasets%20import%20load_dataset%0A%0A%20%20%20%20dataset_name%20%3D%20%22misilelab%2Fyoutube-bot-comments-v2%22%0A%0A%20%20%20%20train_dataset%20%3D%20load_dataset(dataset_name%2C%20split%3D%22train%22).with_format(%22polars%22)%5B%3A%5D%0A%20%20%20%20valid_dataset%20%3D%20load_dataset(dataset_name%2C%20split%3D%22validation%22).with_format(%22polars%22)%5B%3A%5D%0A%20%20%20%20test_dataset%20%20%3D%20load_dataset(dataset_name%2C%20split%3D%22test%22).with_format(%22polars%22)%5B%3A%5D%0A%20%20%20%20return%20test_dataset%2C%20train_dataset%2C%20valid_dataset%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20AdamW%2C%0A%20%20%20%20AutoTokenizer%2C%0A%20%20%20%20DataLoader%2C%0A%20%20%20%20Dataset%2C%0A%20%20%20%20FocalLoss%2C%0A%20%20%20%20SpamUserClassifier%2C%0A%20%20%20%20alt%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20pl%2C%0A%20%20%20%20test_dataset%2C%0A%20%20%20%20torch%2C%0A%20%20%20%20train_dataset%2C%0A%20%20%20%20valid_dataset%2C%0A)%3A%0A%20%20%20%20%23%20import%20re%0A%20%20%20%20%23%20import%20emoji%0A%20%20%20%20%23%20from%20soynlp.normalizer%20import%20repeat_normalize%0A%0A%20%20%20%20%23%20prepare%20tokenizer%0A%20%20%20%20tokenizer%20%3D%20AutoTokenizer.from_pretrained(%22beomi%2FKcELECTRA-base%22)%0A%0A%20%20%20%20%23%20emojis%20%3D%20''.join(emoji.UNICODE_EMOJI.keys())%0A%20%20%20%20%23%20pattern%20%3D%20re.compile(f'%5B%5E%20.%2C%3F!%2F%40%24%25~%EF%BC%85%C2%B7%E2%88%BC()%5Cx00-%5Cx7F%E3%84%B1-%E3%85%A3%EA%B0%80-%ED%9E%A3%7Bemojis%7D%5D%2B')%0A%20%20%20%20%23%20url_pattern%20%3D%20re.compile(%0A%20%20%20%20%23%20%20%20%20%20r'https%3F%3A%5C%2F%5C%2F(www%5C.)%3F%5B-a-zA-Z0-9%40%3A%25._%5C%2B~%23%3D%5D%7B1%2C256%7D%5C.%5Ba-zA-Z0-9()%5D%7B1%2C6%7D%5Cb(%5B-a-zA-Z0-9()%40%3A%25_%5C%2B.~%23%3F%26%2F%2F%3D%5D*)')%0A%20%20%20%20%23%20HTML_TAG_PATTERN%20%3D%20re.compile(r'%3C%5B%5E%3E%5D%2B%3E')%0A%0A%20%20%20%20def%20clean(x%3A%20str)%20-%3E%20str%3A%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20HTML_TAG_PATTERN.sub(''%2C%20x)%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20pattern.sub('%20'%2C%20x)%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20emoji.replace_emoji(x%2C%20replace%3D'')%20%23emoji%20%EC%82%AD%EC%A0%9C%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20url_pattern.sub(''%2C%20x)%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20x.strip()%0A%20%20%20%20%23%20%20%20%20%20x%20%3D%20repeat_normalize(x%2C%20num_repeats%3D2)%0A%20%20%20%20%20%20%20%20return%20x%0A%0A%20%20%20%20%23%20dataset%20wrapper%0A%20%20%20%20class%20YTBotDataset(Dataset)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20ds%2C%20tokenizer%2C%20max_length%3D128)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.author_names%20%3D%20%5Bclean(i)%20for%20i%20in%20ds%5B%22author_name%22%5D.to_list()%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20self.contents%20%3D%20%5Bclean(i)%20for%20i%20in%20ds%5B%22content%22%5D.to_list()%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20self.labels%20%3D%20%5Bint(x)%20for%20x%20in%20ds%5B%22is_bot_comment%22%5D.to_list()%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20self.tokenizer%20%3D%20tokenizer%0A%20%20%20%20%20%20%20%20%20%20%20%20self.max_length%20%3D%20max_length%0A%0A%20%20%20%20%20%20%20%20def%20__len__(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20len(self.labels)%0A%0A%20%20%20%20%20%20%20%20def%20__getitem__(self%2C%20idx)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20author_name%20%3D%20self.author_names%5Bidx%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20content%20%3D%20self.contents%5Bidx%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Tokenize%20author%20name%0A%20%20%20%20%20%20%20%20%20%20%20%20name_encoding%20%3D%20self.tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20author_name%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20truncation%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_length%3Dself.max_length%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Tokenize%20content%0A%20%20%20%20%20%20%20%20%20%20%20%20content_encoding%20%3D%20self.tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20truncation%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_length%3Dself.max_length%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20item%20%3D%20%7B%7D%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Add%20name%20encodings%20with%20prefix%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20k%2C%20v%20in%20name_encoding.items()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20item%5Bf%22name_%7Bk%7D%22%5D%20%3D%20v.squeeze(0)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Add%20content%20encodings%20with%20prefix%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20k%2C%20v%20in%20content_encoding.items()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20item%5Bf%22content_%7Bk%7D%22%5D%20%3D%20v.squeeze(0)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20item%5B%22labels%22%5D%20%3D%20torch.tensor(self.labels%5Bidx%5D%2C%20dtype%3Dtorch.long)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20item%0A%0A%20%20%20%20%23%20create%20datasets%20and%20loaders%0A%20%20%20%20train_ds%20%3D%20YTBotDataset(train_dataset%2C%20tokenizer)%0A%20%20%20%20valid_ds%20%3D%20YTBotDataset(valid_dataset%2C%20tokenizer)%0A%20%20%20%20test_ds%20%20%3D%20YTBotDataset(test_dataset%2C%20tokenizer)%0A%0A%20%20%20%20batch_size%20%3D%2074%0A%20%20%20%20train_loader%20%3D%20DataLoader(train_ds%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DTrue)%0A%20%20%20%20valid_loader%20%3D%20DataLoader(valid_ds%2C%20batch_size%3Dbatch_size)%0A%20%20%20%20test_loader%20%20%3D%20DataLoader(test_ds%2C%20batch_size%3Dbatch_size)%0A%0A%20%20%20%20%23%20Initialize%20model%20and%20move%20to%20device%0A%20%20%20%20device%20%3D%20torch.device(%22cuda%22%20if%20torch.cuda.is_available()%20else%20%22cpu%22)%0A%20%20%20%20model%20%3D%20SpamUserClassifier().to(device)%0A%0A%20%20%20%20%23%20Create%20Focal%20Loss%20for%20imbalanced%20datasets%0A%20%20%20%20criterion%20%3D%20FocalLoss(alpha%3D1.0%2C%20gamma%3D2.0)%0A%0A%20%20%20%20%23%20optimizer%20-%20using%20different%20learning%20rates%20for%20frozen%20and%20unfrozen%20layers%0A%20%20%20%20optimizer%20%3D%20AdamW(%5B%0A%20%20%20%20%20%20%20%20%7B'params'%3A%20%5Bp%20for%20n%2C%20p%20in%20model.named_parameters()%20if%20'encoder'%20in%20n%20and%20p.requires_grad%5D%2C%20'lr'%3A%201e-5%7D%2C%0A%20%20%20%20%20%20%20%20%7B'params'%3A%20%5Bp%20for%20n%2C%20p%20in%20model.named_parameters()%20if%20'encoder'%20not%20in%20n%5D%2C%20'lr'%3A%202e-5%7D%0A%20%20%20%20%5D%2C%20weight_decay%3D0.01)%0A%0A%20%20%20%20%23%20training%20setup%0A%20%20%20%20num_epochs%20%3D%20100%0A%20%20%20%20patience%20%3D%205%0A%20%20%20%20best_valid_acc%20%3D%200.0%0A%20%20%20%20no_improve_epochs%20%3D%200%0A%0A%20%20%20%20%23%20Initialize%20training%20history%0A%20%20%20%20training_history%20%3D%20%7B%0A%20%20%20%20%20%20%20%20'epochs'%3A%20%5B%5D%2C%0A%20%20%20%20%20%20%20%20'train_losses'%3A%20%5B%5D%2C%0A%20%20%20%20%20%20%20%20'valid_losses'%3A%20%5B%5D%2C%0A%20%20%20%20%20%20%20%20'valid_accuracies'%3A%20%5B%5D%0A%20%20%20%20%7D%0A%0A%20%20%20%20%23%20create%20a%20top-level%20progress%20bar%20for%20all%20epochs%0A%20%20%20%20for%20epoch%20in%20(progress_bar%20%3A%3D%20mo.status.progress_bar(range(1%2C%20num_epochs%20%2B%201)%2C%20show_eta%3DTrue%2C%20show_rate%3DTrue))%3A%0A%20%20%20%20%20%20%20%20%23%20training%0A%20%20%20%20%20%20%20%20progress_bar.completion_title%20%3D%20f%22epoch%20%7Bepoch%7D%22%0A%20%20%20%20%20%20%20%20model.train()%0A%20%20%20%20%20%20%20%20running_loss%20%3D%200.0%0A%0A%20%20%20%20%20%20%20%20for%20i%2C%20batch%20in%20enumerate(mo.status.progress_bar(%0A%20%20%20%20%20%20%20%20%20%20%20%20train_loader%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20subtitle%3Df%22Training%20Epoch%20%7Bepoch%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20show_eta%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20show_rate%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20remove_on_exit%3DTrue%0A%20%20%20%20%20%20%20%20))%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20name%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%20%3D%20batch%5B%22name_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%20%3D%20batch%5B%22name_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20batch.get(%22name_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20name_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20name_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20content%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%20%3D%20batch%5B%22content_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%20%3D%20batch%5B%22content_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20batch.get(%22content_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20content_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20content_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20labels%20%3D%20batch%5B%22labels%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20logits%20for%20training%20(not%20probabilities)%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Use%20Focal%20Loss%20directly%20(now%20supports%20binary%20classification)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20criterion(logits%2C%20labels)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss.backward()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Gradient%20clipping%20for%20stability%0A%20%20%20%20%20%20%20%20%20%20%20%20torch.nn.utils.clip_grad_norm_(model.parameters()%2C%20max_norm%3D1.0)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%0A%20%20%20%20%20%20%20%20%20%20%20%20running_loss%20%2B%3D%20loss.item()%0A%0A%20%20%20%20%20%20%20%20avg_train_loss%20%3D%20running_loss%20%2F%20len(train_loader)%0A%0A%20%20%20%20%20%20%20%20%23%20validation%0A%20%20%20%20%20%20%20%20model.eval()%0A%20%20%20%20%20%20%20%20correct%2C%20total%20%3D%200%2C%200%0A%20%20%20%20%20%20%20%20valid_running_loss%20%3D%200.0%0A%0A%20%20%20%20%20%20%20%20for%20i%2C%20batch%20in%20enumerate(mo.status.progress_bar(%0A%20%20%20%20%20%20%20%20%20%20%20%20valid_loader%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20subtitle%3Df%22Validating%20Epoch%20%7Bepoch%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20show_eta%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20show_rate%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20remove_on_exit%3DTrue%0A%20%20%20%20%20%20%20%20))%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20name%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%20%3D%20batch%5B%22name_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%20%3D%20batch%5B%22name_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20batch.get(%22name_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20name_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20name_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20content%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%20%3D%20batch%5B%22content_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%20%3D%20batch%5B%22content_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20batch.get(%22content_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20content_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20content_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20labels%20%3D%20batch%5B%22labels%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20logits%20for%20loss%20calculation%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Calculate%20validation%20loss%20using%20Focal%20Loss%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20criterion(logits%2C%20labels)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20valid_running_loss%20%2B%3D%20loss.item()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20predictions%20from%20logits%20(use%20model's%20sigmoid)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20preds%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_probs%3DFalse%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20correct%20%2B%3D%20(preds%20%3D%3D%20labels).sum().item()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20total%20%2B%3D%20labels.size(0)%0A%0A%20%20%20%20%20%20%20%20valid_acc%20%3D%20correct%20%2F%20total%0A%20%20%20%20%20%20%20%20avg_valid_loss%20%3D%20valid_running_loss%20%2F%20len(valid_loader)%0A%0A%20%20%20%20%20%20%20%20%23%20Store%20training%20history%0A%20%20%20%20%20%20%20%20training_history%5B'epochs'%5D.append(epoch)%0A%20%20%20%20%20%20%20%20training_history%5B'train_losses'%5D.append(avg_train_loss)%0A%20%20%20%20%20%20%20%20training_history%5B'valid_losses'%5D.append(avg_valid_loss)%0A%20%20%20%20%20%20%20%20training_history%5B'valid_accuracies'%5D.append(valid_acc)%0A%0A%20%20%20%20%20%20%20%20print(f%22Epoch%20%7Bepoch%7D%3A%20Train%20Loss%3A%20%7Bavg_train_loss%3A.4f%7D%2C%20Valid%20Loss%3A%20%7Bavg_valid_loss%3A.4f%7D%2C%20Valid%20Acc%3A%20%7Bvalid_acc%3A.4f%7D%22)%0A%0A%20%20%20%20%20%20%20%20%23%20early%20stopping%20check%0A%20%20%20%20%20%20%20%20if%20valid_acc%20%3E%20best_valid_acc%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20best_valid_acc%20%3D%20valid_acc%0A%20%20%20%20%20%20%20%20%20%20%20%20no_improve_epochs%20%3D%200%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Save%20best%20model%0A%20%20%20%20%20%20%20%20%20%20%20%20torch.save(model.state_dict()%2C%20'model.pth')%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20no_improve_epochs%20%2B%3D%201%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20no_improve_epochs%20%3E%3D%20patience%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Early%20stopping%20triggered%20after%20%7Bepoch%7D%20epochs%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20break%0A%0A%20%20%20%20print(f%22Training%20completed.%20Best%20validation%20accuracy%3A%20%7Bbest_valid_acc%3A.4f%7D%22)%0A%0A%20%20%20%20%23%20Create%20and%20display%20final%20training%20chart%0A%20%20%20%20if%20training_history%5B'epochs'%5D%3A%0A%20%20%20%20%20%20%20%20epochs%20%3D%20training_history%5B'epochs'%5D%0A%20%20%20%20%20%20%20%20train_losses%20%3D%20training_history%5B'train_losses'%5D%0A%20%20%20%20%20%20%20%20valid_losses%20%3D%20training_history%5B'valid_losses'%5D%0A%0A%20%20%20%20%20%20%20%20_df%20%3D%20pl.DataFrame(%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20'epoch'%3A%20epochs%20*%202%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20'loss'%3A%20train_losses%20%2B%20valid_losses%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20'type'%3A%20%5B'Train%20Loss'%5D%20*%20len(train_losses)%20%2B%20%5B'Validation%20Loss'%5D%20*%20len(valid_losses)%0A%20%20%20%20%20%20%20%20%7D)%0A%0A%20%20%20%20%20%20%20%20final_chart%20%3D%20alt.Chart(_df).mark_line(point%3DTrue).encode(%0A%20%20%20%20%20%20%20%20%20%20%20%20x%3Dalt.X('epoch%3AQ'%2C%20title%3D'Epoch')%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20y%3Dalt.Y('loss%3AQ'%2C%20title%3D'Loss')%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20color%3Dalt.Color('type%3AN'%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20scale%3Dalt.Scale(domain%3D%5B'Train%20Loss'%2C%20'Validation%20Loss'%5D%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20range%3D%5B'firebrick'%2C%20'royalblue'%5D))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20tooltip%3D%5B'epoch%3AQ'%2C%20'loss%3AQ'%2C%20'type%3AN'%5D%0A%20%20%20%20%20%20%20%20).properties(%0A%20%20%20%20%20%20%20%20%20%20%20%20title%3D'Training%20and%20Validation%20Loss%20Over%20Epochs%20(Focal%20Loss)'%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20width%3D700%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20height%3D400%0A%20%20%20%20%20%20%20%20).interactive()%0A%0A%20%20%20%20%23%20Test%20the%20final%20model%0A%20%20%20%20print(%22%5CnEvaluating%20on%20test%20set...%22)%0A%20%20%20%20model.eval()%0A%20%20%20%20test_correct%2C%20test_total%20%3D%200%2C%200%0A%20%20%20%20test_predictions%20%3D%20%5B%5D%0A%20%20%20%20test_labels%20%3D%20%5B%5D%0A%0A%20%20%20%20for%20batch%20in%20mo.status.progress_bar(test_loader%2C%20subtitle%3D%22Testing%22%2C%20show_eta%3DTrue%2C%20remove_on_exit%3DTrue)%3A%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20name%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%20%3D%20batch%5B%22name_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%20%3D%20batch%5B%22name_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20batch.get(%22name_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20name_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20name_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20content%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%20%3D%20batch%5B%22content_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%20%3D%20batch%5B%22content_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20batch.get(%22content_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20content_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20content_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20labels%20%3D%20batch%5B%22labels%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20logits%20and%20convert%20to%20predictions%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20predictions%20using%20model's%20sigmoid%0A%20%20%20%20%20%20%20%20%20%20%20%20preds%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_probs%3DFalse%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20test_correct%20%2B%3D%20(preds%20%3D%3D%20labels).sum().item()%0A%20%20%20%20%20%20%20%20%20%20%20%20test_total%20%2B%3D%20labels.size(0)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Store%20for%20detailed%20analysis%0A%20%20%20%20%20%20%20%20%20%20%20%20test_predictions.extend(preds.cpu().numpy())%0A%20%20%20%20%20%20%20%20%20%20%20%20test_labels.extend(labels.cpu().numpy())%0A%0A%20%20%20%20test_acc%20%3D%20test_correct%20%2F%20test_total%0A%20%20%20%20print(f%22Test%20Accuracy%3A%20%7Btest_acc%3A.4f%7D%22)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20device%2C%0A%20%20%20%20%20%20%20%20final_chart%2C%0A%20%20%20%20%20%20%20%20model%2C%0A%20%20%20%20%20%20%20%20test_labels%2C%0A%20%20%20%20%20%20%20%20test_loader%2C%0A%20%20%20%20%20%20%20%20test_predictions%2C%0A%20%20%20%20%20%20%20%20tokenizer%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(final_chart)%3A%0A%20%20%20%20final_chart%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20model_save_path%20%3D%20%22model.pth%22%0A%20%20%20%20return%20(model_save_path%2C)%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20model_save_path%2C%20torch)%3A%0A%20%20%20%20%23%20Save%20the%20trained%20model's%20state_dict%0A%20%20%20%20torch.save(model.state_dict()%2C%20model_save_path)%0A%20%20%20%20model_save_path%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%20Test%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(test_labels%2C%20test_predictions)%3A%0A%20%20%20%20%23%20Calculate%20additional%20metrics%20for%20imbalanced%20dataset%20evaluation%0A%20%20%20%20from%20sklearn.metrics%20import%20classification_report%2C%20confusion_matrix%0A%20%20%20%20print(%22%5CnDetailed%20Classification%20Report%3A%22)%0A%20%20%20%20print(classification_report(test_labels%2C%20test_predictions%2C%20target_names%3D%5B'Normal'%2C%20'Spam'%5D))%0A%20%20%20%20print(%22%5CnConfusion%20Matrix%3A%22)%0A%20%20%20%20print(confusion_matrix(test_labels%2C%20test_predictions))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(device%2C%20mo%2C%20model%2C%20test_loader%2C%20torch)%3A%0A%20%20%20%20def%20_()%3A%0A%20%20%20%20%20%20%20%20import%20seaborn%20as%20sns%0A%20%20%20%20%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20%20%20%20%20from%20sklearn.metrics%20import%20confusion_matrix%2C%20accuracy_score%0A%0A%20%20%20%20%20%20%20%20%23%20Load%20saved%20model%20state%0A%20%20%20%20%20%20%20%20model.load_state_dict(torch.load('model.pth'%2C%20map_location%3Ddevice))%0A%20%20%20%20%20%20%20%20model.to(device)%0A%20%20%20%20%20%20%20%20model.eval()%0A%0A%20%20%20%20%20%20%20%20all_preds%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20all_labels%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20all_probs%20%3D%20%5B%5D%0A%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20batch%20in%20mo.status.progress_bar(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20test_loader%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20title%3D%22Computing%20Confusion%20Matrix%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20show_eta%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20show_rate%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20name%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%20%3D%20batch%5B%22name_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%20%3D%20batch%5B%22name_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20batch.get(%22name_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20name_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%20%3D%20name_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Extract%20content%20inputs%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%20%3D%20batch%5B%22content_input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%20%3D%20batch%5B%22content_attention_mask%22%5D.to(device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20batch.get(%22content_token_type_ids%22%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20content_token_type_ids%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%20%3D%20content_token_type_ids.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20labels%20%3D%20batch%5B%22labels%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20raw%20logits%20for%20probability%20calculation%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_token_type_ids%3Dname_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_token_type_ids%3Dcontent_token_type_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DTrue%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Calculate%20probabilities%20using%20model's%20sigmoid%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20probs%20%3D%20torch.sigmoid(logits.squeeze(-1))%20%20%23%20Bot%20probability%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Use%20custom%20threshold%20(0.8%20for%20high%20precision)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20preds%20%3D%20(probs%20%3E%200.8).long()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20all_preds.extend(preds.cpu().numpy())%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20all_labels.extend(labels.cpu().numpy())%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20all_probs.extend(probs.cpu().numpy())%0A%0A%20%20%20%20%20%20%20%20%23%20Compute%20metrics%0A%20%20%20%20%20%20%20%20cm%20%3D%20confusion_matrix(all_labels%2C%20all_preds)%0A%20%20%20%20%20%20%20%20acc%20%3D%20accuracy_score(all_labels%2C%20all_preds)%0A%20%20%20%20%20%20%20%20print(f%22Accuracy%20with%200.8%20threshold%3A%20%7Bacc%20*%20100%3A.2f%7D%25%22)%0A%0A%20%20%20%20%20%20%20%20%23%20Print%20additional%20metrics%0A%20%20%20%20%20%20%20%20tn%2C%20fp%2C%20fn%2C%20tp%20%3D%20cm.ravel()%0A%20%20%20%20%20%20%20%20precision%20%3D%20tp%20%2F%20(tp%20%2B%20fp)%20if%20(tp%20%2B%20fp)%20%3E%200%20else%200%0A%20%20%20%20%20%20%20%20recall%20%3D%20tp%20%2F%20(tp%20%2B%20fn)%20if%20(tp%20%2B%20fn)%20%3E%200%20else%200%0A%20%20%20%20%20%20%20%20f1%20%3D%202%20*%20(precision%20*%20recall)%20%2F%20(precision%20%2B%20recall)%20if%20(precision%20%2B%20recall)%20%3E%200%20else%200%0A%0A%20%20%20%20%20%20%20%20print(f%22Precision%3A%20%7Bprecision%20*%20100%3A.2f%7D%25%22)%0A%20%20%20%20%20%20%20%20print(f%22Recall%3A%20%7Brecall%20*%20100%3A.2f%7D%25%22)%0A%20%20%20%20%20%20%20%20print(f%22F1-Score%3A%20%7Bf1%20*%20100%3A.2f%7D%25%22)%0A%20%20%20%20%20%20%20%20print(f%22True%20Negatives%3A%20%7Btn%7D%2C%20False%20Positives%3A%20%7Bfp%7D%22)%0A%20%20%20%20%20%20%20%20print(f%22False%20Negatives%3A%20%7Bfn%7D%2C%20True%20Positives%3A%20%7Btp%7D%22)%0A%0A%20%20%20%20%20%20%20%20%23%20Plot%20heatmap%0A%20%20%20%20%20%20%20%20plt.figure(figsize%3D(8%2C%206))%0A%20%20%20%20%20%20%20%20sns.heatmap(%0A%20%20%20%20%20%20%20%20%20%20%20%20cm%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20annot%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20fmt%3D'd'%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20cmap%3D'Blues'%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20xticklabels%3D%5B'Normal%20User'%2C%20'Bot%2FSpam'%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20yticklabels%3D%5B'Normal%20User'%2C%20'Bot%2FSpam'%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20cbar_kws%3D%7B'label'%3A%20'Count'%7D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20plt.xlabel('Predicted%20Label')%0A%20%20%20%20%20%20%20%20plt.ylabel('True%20Label')%0A%20%20%20%20%20%20%20%20plt.title('Confusion%20Matrix%20-%20Spam%20Detection%20(Threshold%3A%200.8)')%0A%20%20%20%20%20%20%20%20plt.tight_layout()%0A%0A%20%20%20%20%20%20%20%20%23%20Add%20text%20annotations%20for%20better%20interpretation%0A%20%20%20%20%20%20%20%20plt.text(0.5%2C%20-0.1%2C%20f'Accuracy%3A%20%7Bacc%3A.3f%7D%20%7C%20Precision%3A%20%7Bprecision%3A.3f%7D%20%7C%20Recall%3A%20%7Brecall%3A.3f%7D%20%7C%20F1%3A%20%7Bf1%3A.3f%7D'%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ha%3D'center'%2C%20va%3D'top'%2C%20transform%3Dplt.gca().transAxes%2C%20fontsize%3D10)%0A%0A%20%20%20%20%20%20%20%20%23%20Optional%3A%20Also%20create%20a%20probability%20distribution%20plot%0A%20%20%20%20%20%20%20%20def%20plot_probability_distribution()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Separate%20probabilities%20by%20true%20labels%0A%20%20%20%20%20%20%20%20%20%20%20%20normal_probs%20%3D%20%5Bprob%20for%20prob%2C%20label%20in%20zip(all_probs%2C%20all_labels)%20if%20label%20%3D%3D%200%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20spam_probs%20%3D%20%5Bprob%20for%20prob%2C%20label%20in%20zip(all_probs%2C%20all_labels)%20if%20label%20%3D%3D%201%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.figure(figsize%3D(10%2C%206))%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.hist(normal_probs%2C%20bins%3D50%2C%20alpha%3D0.7%2C%20label%3D'Normal%20Users'%2C%20color%3D'blue'%2C%20density%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.hist(spam_probs%2C%20bins%3D50%2C%20alpha%3D0.7%2C%20label%3D'Bot%2FSpam'%2C%20color%3D'red'%2C%20density%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.axvline(x%3D0.5%2C%20color%3D'green'%2C%20linestyle%3D'--'%2C%20label%3D'Default%20Threshold%20(0.5)')%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.axvline(x%3D0.8%2C%20color%3D'orange'%2C%20linestyle%3D'--'%2C%20label%3D'Current%20Threshold%20(0.8)')%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.xlabel('Predicted%20Bot%20Probability')%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.ylabel('Density')%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.title('Distribution%20of%20Predicted%20Probabilities%20by%20True%20Label')%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.legend()%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.grid(True%2C%20alpha%3D0.3)%0A%20%20%20%20%20%20%20%20%20%20%20%20plt.tight_layout()%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20plt.gca()%0A%0A%20%20%20%20%20%20%20%20%23%20Uncomment%20to%20see%20probability%20distribution%0A%20%20%20%20%20%20%20%20return%20(plot_probability_distribution()%2C%20plt.gca())%0A%0A%20%20%20%20_()%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%20Evaluation%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(device%2C%20model%2C%20tokenizer%2C%20torch)%3A%0A%20%20%20%20def%20_()%3A%0A%20%20%20%20%20%20%20%23%20Evaluate%20a%20single%20user%20input%20comment%0A%20%20%20%20%20%20%20model.eval()%0A%20%20%20%20%20%20%20author_name%20%3D%20input(%22Enter%20the%20author%20name%3A%20%22)%0A%20%20%20%20%20%20%20comment%20%3D%20input(%22Enter%20a%20YouTube%20comment%20to%20evaluate%3A%20%22)%0A%0A%20%20%20%20%20%20%20%23%20Tokenize%20author%20name%0A%20%20%20%20%20%20%20name_encoding%20%3D%20tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20author_name%2C%0A%20%20%20%20%20%20%20%20%20%20%20truncation%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20max_length%3D128%2C%0A%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%0A%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20name_input_ids%20%3D%20name_encoding%5B%22input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20name_attention_mask%20%3D%20name_encoding%5B%22attention_mask%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%23%20Tokenize%20content%0A%20%20%20%20%20%20%20content_encoding%20%3D%20tokenizer(%0A%20%20%20%20%20%20%20%20%20%20%20comment%2C%0A%20%20%20%20%20%20%20%20%20%20%20truncation%3DTrue%2C%0A%20%20%20%20%20%20%20%20%20%20%20padding%3D%22max_length%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20max_length%3D128%2C%0A%20%20%20%20%20%20%20%20%20%20%20return_tensors%3D%22pt%22%0A%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20content_input_ids%20%3D%20content_encoding%5B%22input_ids%22%5D.to(device)%0A%20%20%20%20%20%20%20content_attention_mask%20%3D%20content_encoding%5B%22attention_mask%22%5D.to(device)%0A%0A%20%20%20%20%20%20%20%23%20Get%20prediction%20using%20model's%20built-in%20sigmoid%20and%20threshold%0A%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20prediction%20%3D%20model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_input_ids%3Dname_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_input_ids%3Dcontent_input_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20name_attention_mask%3Dname_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20content_attention_mask%3Dcontent_attention_mask%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_logits%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return_probs%3Dinput(%22Return%20probabilities%3F%20(y%2Fn)%3A%20%22)%20%3D%3D%20%22y%22%0A%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%23%20Print%20result%0A%20%20%20%20%20%20%20print(f%22Prediction%3A%20%7Bprediction%7D%22)%0A%0A%20%20%20%20_()%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
16a785acfc1523aa8260da93fcefc88cc2db64cbb1f66105347244cf3491762a