# tasks.py (5.1 KB)
  1. from celery_app import celery_app
  2. import json
  3. import chromadb
  4. client = chromadb.HttpClient(host='47.101.198.30',port=8000)
  5. collection = client.get_or_create_collection(name="tj_de_bge")
  6. from FlagEmbedding import FlagModel
  7. model = FlagModel('/Users/zxp/Downloads/test2_encoder_only_base_bge-large-zh-v1.5')
  8. from sentence_transformers import CrossEncoder
  9. ce = CrossEncoder('/Users/zxp/Downloads/reranker')
  10. with open("hunningtu_rule", "r") as f:
  11. content = f.read()
  12. obj = json.loads(content)
  13. with open("nantong_rule", "r") as f:
  14. content = f.read()
  15. obj2 = json.loads(content)
  16. with open("basic_rule", "r") as f:
  17. content = f.read()
  18. basic = json.loads(content)
  19. with open("incremental_rule", "r") as f:
  20. content = f.read()
  21. incremental = json.loads(content)
  22. with open("label_name", "r") as f:
  23. content = f.read()
  24. label_name = json.loads(content)
  25. THRESHOLD=0.9####adjust it
  26. @celery_app.task
  27. def process_data(data:dict)-> dict:
  28. label = data['mc'] + ' ' + data['tz']
  29. sentences = [label]
  30. embeddings = model.encode(sentences)
  31. result = collection.query(query_embeddings=embeddings,n_results=25)
  32. d = result['documents'][0]
  33. print(d)
  34. ranks = ce.rank(label, d)
  35. ranks = ranks[:10]
  36. match = [("6.2.1.1","6.3.1.1"),
  37. ("6.2.1.2", "6.3.1.2"),
  38. ("6.2.1.3", "6.3.1.3"),
  39. ("6.2.1.4", "6.3.1.4"),
  40. ("6.2.1.5", "6.3.1.5"),
  41. ("6.2.1.6", "6.3.1.6"),
  42. ("6.2.2.1", "6.3.2.1"),
  43. ("6.2.2.2", "6.3.2.2"),
  44. ("6.2.3.1", "6.3.3.1"),
  45. ("6.2.3.2", "6.3.3.2"),
  46. ("6.2.3.3", "6.3.3.3"),
  47. ("6.2.3.4", "6.3.3.4"),
  48. ("6.2.3.5", "6.3.3.5"),
  49. ("6.2.3.6", "6.3.3.6")]
  50. match2=[
  51. ("1.1.7", "nantong1.1.7"),
  52. ("nantong2.1.2", "2.1.2"),
  53. ("nantong3.1.2", "3.1.2"),
  54. ("nantong3.1.4", "3.1.4"),
  55. ("nantong3.1.5", "3.1.5"),
  56. ("3.2.10", "nantong3.2.10"),
  57. ("nantong4.1.1", "4.1.1"),
  58. ("nantong4.1.2", "4.1.2"),
  59. ("nantong4.1.3", "4.1.3"),
  60. ("4.1.4", "nantong4.1.4"),
  61. ("4.1.5", "nantong4.1.5"),
  62. ("4.1.7", "nantong4.1.7"),
  63. ("4.4", "nantong4.4"),
  64. ("nantong6", "6"),
  65. ("7.5", "nantong7.5"),
  66. ("nantong7.8", "7.8"),
  67. ("10.1.5", "nantong10.1.5"),
  68. ("10.1.2", "nantong10.1.2"),
  69. ("10.1.1", "nantong10.1.1"),
  70. ("nantong10.1.1.2", "10.1.1.2"),
  71. ("10.1.1.3", "nantong10.1.1.3"),
  72. ("nantong11.1.2.1", "11.1.2.1"),
  73. ("nantong11.1.2.2", "11.1.2.2"),
  74. ("nantong11.1.1", "11.1.1"),
  75. ("12.7", "nantong12.7"),
  76. ("12.6", "nantong12.6"),
  77. ("nantong12.5", "12.5"),
  78. ("nantong13.1.1", "13.1.1"),
  79. ("nantong13.1.2" , "13.1.2"),
  80. ("nantong13.1.3", "13.1.3"),
  81. ("nantong13.2.2", "13.2.2"),
  82. ("nantong13.3.1", "13.3.1"),
  83. ("nantong13.3.2", "13.3.2"),
  84. ("13.3.3" ,"nantong13.3.3"),
  85. ("13.4.4", "nantong13.4.4"),
  86. ("nantong13.5.1", "13.5.1"),
  87. ("13.5.4", "nantong13.5.4"),
  88. ("nantong14.3.8", "14.3.8"),
  89. ("14.4.4", "nantong14.4.4"),
  90. ("14.4.6", "nantong14.4.6"),
  91. ("nantong15.3.1", "15.3.1"),
  92. ("16.2", "nantong16.2"),
  93. ("17.1.3.2", "nantong17.1.3.2"),
  94. ("17.1.3.3","nantong17.1.3.3"),
  95. ("17.1.3.4","nantong17.1.3.4"),
  96. ("18.3.3","nantong18.3.3"),
  97. ("18.3.2","nantong18.3.2"),
  98. ("18.5","nantong18.5"),
  99. ("18.6","nantong18.6"),
  100. ("18.15","nantong18.15"),
  101. ("20.1.1","nantong20.1.1"),
  102. ("20.1.2.1","nantong20.1.2.1"),
  103. ("20.1.2.3","nantong20.1.2.3"),
  104. ("20.1.2.5","nantong20.1.2.5"),
  105. ("21.1.1.1","nantong21.1.1.1"),
  106. ("21.1.1.2","nantong21.1.1.2"),
  107. ("21.1.3.1","nantong21.1.3.1"),
  108. ("21.1.3.2","nantong21.1.3.2"),
  109. ("21.1.3.3","nantong21.1.3.3"),
  110. ("21.1.5","nantong21.1.5"),
  111. ("21.1.6","nantong21.1.6"),
  112. ("21.1.7","nantong21.1.7"),
  113. ("23.1.2","nantong23.1.2")
  114. ]
  115. selected=[]
  116. notselected=[]
  117. for entry in incremental:
  118. notselected = notselected + incremental[entry]
  119. for rank in ranks:
  120. if rank['score']<THRESHOLD:
  121. continue
  122. if d[rank['corpus_id']] in notselected:
  123. continue
  124. print(f"{rank['score']} {d[rank['corpus_id']]}")
  125. selected.append(d[rank['corpus_id']])
  126. hunningtu_group = []
  127. for entry in obj:
  128. if d[rank['corpus_id']] in obj[entry]:
  129. hunningtu_group=[entry]
  130. if len(hunningtu_group) > 0:
  131. for entry in match:
  132. if entry[0]==hunningtu_group[0] or entry[1] == hunningtu_group[0]:
  133. notselected = notselected + obj[entry[0]]
  134. notselected = notselected + obj[entry[1]]
  135. nantong_group = []
  136. for entry in obj2:
  137. if d[rank['corpus_id']] in obj2[entry]:
  138. nantong_group=[entry]
  139. if len(nantong_group) > 0:
  140. for entry in match2:
  141. if entry[0]==nantong_group[0] or entry[1] == nantong_group[0]:
  142. notselected = notselected + obj2[entry[0]]
  143. notselected = notselected + obj2[entry[1]]
  144. for entry in basic:
  145. if d[rank['corpus_id']] in basic[entry]:
  146. notselected = notselected + basic[entry]
  147. notselected = [x for x in notselected if x not in selected]
  148. result = [label_name[x] for x in selected]
  149. return {"result": result}