# modify.py (5.2 KB)

  1. from datasets import load_dataset
  2. ds = load_dataset("json", data_files="corrected_with_wang.json", split="train")
  3. query = [ds[i]['query'] for i in range(len(ds))]
  4. mapping={}
  5. mapping05={}
  6. mapping0={}
  7. for i in query:
  8. mapping[i]=[]
  9. mapping05[i]=[]
  10. mapping0[i]=[]
  11. for i in range(len(ds)):
  12. q = ds[i]['query']
  13. r = ds[i]['response']
  14. l = ds[i]['label']
  15. if l > 0.9:
  16. mapping[q].append(r)
  17. elif l > 0.4:
  18. mapping05[q].append(r)
  19. else:
  20. mapping0[q].append(r)
  21. with open("../hunningtu_rule", "r") as f:
  22. content = f.read()
  23. import json
  24. obj = json.loads(content)
  25. with open("../nantong_rule", "r") as f:
  26. content = f.read()
  27. obj2 = json.loads(content)
  28. for entry in obj2:
  29. print(entry)
  30. match = [("6.2.1.1","6.3.1.1"),
  31. ("6.2.1.2", "6.3.1.2"),
  32. ("6.2.1.3", "6.3.1.3"),
  33. ("6.2.1.4", "6.3.1.4"),
  34. ("6.2.1.5", "6.3.1.5"),
  35. ("6.2.1.6", "6.3.1.6"),
  36. ("6.2.2.1", "6.3.2.1"),
  37. ("6.2.2.2", "6.3.2.2"),
  38. ("6.2.3.1", "6.3.3.1"),
  39. ("6.2.3.2", "6.3.3.2"),
  40. ("6.2.3.3", "6.3.3.3"),
  41. ("6.2.3.4", "6.3.3.4"),
  42. ("6.2.3.5", "6.3.3.5"),
  43. ("6.2.3.6", "6.3.3.6")]
  44. match2=[
  45. ("1.1.7", "nantong1.1.7"),
  46. ("nantong2.1.2", "2.1.2"),
  47. ("nantong3.1.2", "3.1.2"),
  48. ("nantong3.1.4", "3.1.4"),
  49. ("nantong3.1.5", "3.1.5"),
  50. ("3.2.10", "nantong3.2.10"),
  51. ("nantong4.1.1", "4.1.1"),
  52. ("nantong4.1.2", "4.1.2"),
  53. ("nantong4.1.3", "4.1.3"),
  54. ("4.1.4", "nantong4.1.4"),
  55. ("4.1.5", "nantong4.1.5"),
  56. ("4.1.7", "nantong4.1.7"),
  57. ("4.4", "nantong4.4"),
  58. ("nantong6", "6"),
  59. ("7.5", "nantong7.5"),
  60. ("nantong7.8", "7.8"),
  61. ("10.1.5", "nantong10.1.5"),
  62. ("10.1.2", "nantong10.1.2"),
  63. ("10.1.1", "nantong10.1.1"),
  64. ("nantong10.1.1.2", "10.1.1.2"),
  65. ("10.1.1.3", "nantong10.1.1.3"),
  66. ("nantong11.1.2.1", "11.1.2.1"),
  67. ("nantong11.1.2.2", "11.1.2.2"),
  68. ("nantong11.1.1", "11.1.1"),
  69. ("12.7", "nantong12.7"),
  70. ("12.6", "nantong12.6"),
  71. ("nantong12.5", "12.5"),
  72. ("nantong13.1.1", "13.1.1"),
  73. ("nantong13.1.2" , "13.1.2"),
  74. ("nantong13.1.3", "13.1.3"),
  75. ("nantong13.2.2", "13.2.2"),
  76. ("nantong13.3.1", "13.3.1"),
  77. ("nantong13.3.2", "13.3.2"),
  78. ("13.3.3" ,"nantong13.3.3"),
  79. ("13.4.4", "nantong13.4.4"),
  80. ("nantong13.5.1", "13.5.1"),
  81. ("13.5.4", "nantong13.5.4"),
  82. ("nantong14.3.8", "14.3.8"),
  83. ("14.4.4", "nantong14.4.4"),
  84. ("14.4.6", "nantong14.4.6"),
  85. ("nantong15.3.1", "15.3.1"),
  86. ("16.2", "nantong16.2"),
  87. ("17.1.3.2", "nantong17.1.3.2"),
  88. ("17.1.3.3","nantong17.1.3.3"),
  89. ("17.1.3.4","nantong17.1.3.4"),
  90. ("18.3.3","nantong18.3.3"),
  91. ("18.3.2","nantong18.3.2"),
  92. ("18.5","nantong18.5"),
  93. ("18.6","nantong18.6"),
  94. ("18.15","nantong18.15"),
  95. ("20.1.1","nantong20.1.1"),
  96. ("20.1.2.1","nantong20.1.2.1"),
  97. ("20.1.2.3","nantong20.1.2.3"),
  98. ("20.1.2.5","nantong20.1.2.5"),
  99. ("21.1.1.1","nantong21.1.1.1"),
  100. ("21.1.1.2","nantong21.1.1.2"),
  101. ("21.1.3.1","nantong21.1.3.1"),
  102. ("21.1.3.2","nantong21.1.3.2"),
  103. ("21.1.3.3","nantong21.1.3.3"),
  104. ("21.1.5","nantong21.1.5"),
  105. ("21.1.6","nantong21.1.6"),
  106. ("21.1.7","nantong21.1.7"),
  107. ("23.1.2","nantong23.1.2")
  108. ]
  109. for q in mapping:
  110. correct = mapping[q]
  111. new05=[]
  112. for c in correct:
  113. for tuple_ in match:
  114. left = tuple_[0]
  115. right = tuple_[1]
  116. t = obj[left]+obj[right]
  117. if c in obj[left] or c in obj[right]:
  118. new05 = new05 + t
  119. for tuple_ in match2:
  120. left = tuple_[0]
  121. right = tuple_[1]
  122. t = obj2[left]+obj2[right]
  123. if c in obj2[left] or c in obj2[right]:
  124. new05 = new05 + t
  125. new05 = [x for x in new05 if x not in correct]
  126. zero = mapping0[q]
  127. zero2 = [x for x in zero if x not in new05]
  128. old05 = mapping05[q]
  129. old05_u = [x for x in old05 if x not in new05]
  130. old05_u = old05_u + new05
  131. mapping05[q]=old05_u
  132. mapping0[q]=zero2
  133. query=[]
  134. response=[]
  135. label=[]
  136. for i in mapping:
  137. for j in mapping[i]:
  138. query.append(i)
  139. response.append(j)
  140. label.append(1)
  141. for i in mapping05:
  142. for j in mapping05[i]:
  143. query.append(i)
  144. response.append(j)
  145. label.append(0.5)
  146. for i in mapping0:
  147. for j in mapping0[i]:
  148. query.append(i)
  149. response.append(j)
  150. label.append(0)
  151. from datasets import Dataset
  152. Dataset.from_dict({"query": query, "response": response, "label": label}).to_json("corrected_with_wang_modified.json")
  153. query=[]
  154. docs=[]
  155. labels=[]
  156. import random
  157. for i in mapping:
  158. query.append(i)
  159. one=list(set(mapping[i]))
  160. zero = list(set(mapping0[i]))
  161. half = list(set(mapping05[i]))
  162. if len(one) + len(zero) + len(half) < 16:
  163. pass
  164. else:
  165. if len(zero) >= len(half):
  166. zero = random.sample(zero, len(half))
  167. else:
  168. half = random.sample(half, len(zero))
  169. target = 16 - len(one)
  170. target = target // 2
  171. zero = random.sample(zero, target)
  172. half = random.sample(half, target)
  173. doc=[]
  174. label=[]
  175. for entry in one:
  176. doc.append(entry)
  177. label.append(1)
  178. for entry in half:
  179. if random.random()>0:
  180. doc.append(entry)
  181. label.append(0.5)
  182. for entry in zero:
  183. if random.random() > 0:
  184. doc.append(entry)
  185. label.append(0)
  186. docs.append(doc)
  187. labels.append(label)
  188. print(i)
  189. print(len(label))
  190. Dataset.from_dict({"query": query, "docs": docs, "labels": labels}).to_json("corrected_with_wang_lambda.json")