import os
import glob

from dotenv import load_dotenv
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset

load_dotenv()

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # Changed to the specified model
OUTPUT_DIR = "trained_model"
TRAIN_FILE = "train.jsonl"  # reserved; not used below
MAX_LENGTH = 512
BATCH_SIZE = 8
EPOCHS = 3


def create_training_data(posts_dir):
    """Collect the Markdown body of every post under posts_dir/<lang>/*.md."""
    all_texts = []
    for lang_dir in os.listdir(posts_dir):
        lang_path = os.path.join(posts_dir, lang_dir)
        if not os.path.isdir(lang_path):
            continue
        for file_path in glob.glob(os.path.join(lang_path, "*.md")):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Remove YAML front matter (only when the file actually starts with it,
                # so a "---" rule in the body is not mistaken for a delimiter)
                if content.startswith("---"):
                    content = content.split("---", 2)[-1].strip()
                all_texts.append(content)
            except Exception as e:
                print(f"Error reading file {file_path}: {e}")
    return all_texts


def prepare_dataset(texts, tokenizer):
    """Tokenize the raw texts; padding is left to the collator, per batch."""
    encodings = tokenizer(texts, truncation=True, max_length=MAX_LENGTH)
    return Dataset.from_dict(encodings)


def train_model(dataset, tokenizer):
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        save_steps=10_000,
        save_total_limit=2,
        prediction_loss_only=True,
        remove_unused_columns=False,
    )
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # mlm=False yields causal (next-token) labels instead of masked-LM labels
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)


def main():
    posts_dir = "_posts"
    texts = create_training_data(posts_dir)
    # AutoTokenizer resolves the correct tokenizer class for the checkpoint
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Reuse EOS as the pad token so batched padding works
    tokenizer.pad_token = tokenizer.eos_token
    dataset = prepare_dataset(texts, tokenizer)
    train_model(dataset, tokenizer)


if __name__ == "__main__":
    main()
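# --- Illustrative addition, not part of the original training flow ---
# A minimal sketch of loading the fine-tuned weights for generation, assuming
# the run above has written the model to OUTPUT_DIR. The prompt and
# max_new_tokens values are arbitrary placeholders. Note that save_model()
# stores only the model, so the tokenizer is reloaded from MODEL_NAME.
def generate_sample(prompt="Write a short blog post about"):
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(OUTPUT_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)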