fix long seq bug: split over-length tokenized samples into segment_max_length chunks
parent a62ac2658f
commit b9bc7eaf35

@@ -22,4 +22,4 @@ log_interval = 50
 eval_interval = 500
 save_interval = 1000
 work_dir = "data/saved_ckpt/"
-ckpt_path = "data/saved_ckpt/30000.pt"
+ckpt_path = "data/saved_ckpt/40000.pt"

@@ -21,11 +21,11 @@ def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=
     past = None
     for i, path in paths:
         dataset_name = path.split("-")[-2]
-        if num_processes > 1 and i % num_processes != process_index:
-            continue
         if past != dataset_name:
             print("Loading data from {}".format(path))
             past = path
+        if num_processes > 1 and i % num_processes != process_index:
+            continue
         if path.endswith("jsonl.zst"):
            with zstd.open(path, "r", encoding="utf-8") as fp:
                 for line in fp:

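For orientation, a minimal standalone sketch of the round-robin sharding that the i % num_processes != process_index check in create_data_iter performs; the file names below are made up for illustration and are not from the repository:

# Hypothetical paths; each process keeps only the indices assigned to it.
paths = list(enumerate([
    "data/part-news-0.jsonl.zst",
    "data/part-news-1.jsonl.zst",
    "data/part-wiki-0.jsonl.zst",
    "data/part-wiki-1.jsonl.zst",
]))
num_processes = 2
for process_index in range(num_processes):
    kept = [
        path
        for i, path in paths
        if not (num_processes > 1 and i % num_processes != process_index)
    ]
    print(process_index, kept)
# process 0 keeps indices 0 and 2, process 1 keeps indices 1 and 3,
# so the data files are split across processes without overlap.
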
@@ -8,6 +8,7 @@ Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
+import math
 
 
 def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
@@ -23,7 +24,10 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
         total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
         out = tokenizer(total)
         input_ids = out["input_ids"]
-        return [input_ids]
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
 
     return preprocess_self_instruction
 
@@ -43,7 +47,10 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
         total = "user:{}<s>system:{}".format(prompt, completion)
         out = tokenizer(total)
         input_ids = out["input_ids"]
-        return [input_ids]
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
 
     return preprocess_belle

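For clarity, a minimal standalone sketch of the new segmentation logic; the helper name split_segments and the dummy token ids are illustrative, not part of the commit. A tokenized sample longer than segment_max_length is now split into ceil(len / segment_max_length) chunks rather than returned as a single over-long segment:

import math

def split_segments(input_ids, segment_max_length=1024):
    # Same list comprehension as the fixed return statements:
    # ceil(len / segment_max_length) chunks, each at most
    # segment_max_length tokens, with the final chunk possibly shorter.
    return [
        input_ids[i * segment_max_length : (i + 1) * segment_max_length]
        for i in range(math.ceil(len(input_ids) / segment_max_length))
    ]

# A 2500-token sample (dummy ids) now yields three segments instead of one.
segments = split_segments(list(range(2500)))
print([len(seg) for seg in segments])  # [1024, 1024, 452]

Before this change, return [input_ids] handed the whole sample to training as one segment regardless of its length, which is presumably the long-sequence bug the commit title refers to.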