Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file bin_pred.ml
(** Confusion matrix of a binary classifier: counts of true positives,
    true negatives, false positives and false negatives. *)
type confusion_matrix = {
  tp : int ;
  tn : int ;
  fp : int ;
  fn : int ;
}
[@@deriving sexp]

(** A performance curve: [(threshold, confusion_matrix)] pairs, as built by
    [performance_curve]. *)
type curve = (float * confusion_matrix) array
[@@deriving sexp]

(** The confusion matrix with every count equal to zero. *)
let zero = { tp = 0 ; tn = 0 ; fn = 0 ; fp = 0 }

(* Adds one example to [accu]. The example is predicted positive iff
   [threshold < score] (strict inequality); [label] is its true class. *)
let update accu ~threshold ~score ~label =
  match Float.(threshold < score), label with
  | true, true -> { accu with tp = accu.tp + 1 }
  | true, false -> { accu with fp = accu.fp + 1 }
  | false, true -> { accu with fn = accu.fn + 1 }
  | false, false -> { accu with tn = accu.tn + 1 }

(** [confusion_matrix ~scores ~labels ~threshold] compares the predictions
    [score > threshold] with [labels], example by example.
    @raise Invalid_argument if [scores] and [labels] have different
    lengths. *)
let confusion_matrix ~scores ~labels ~threshold =
  if Array.length scores <> Array.length labels
  then invalid_argf "Bin_pred.confusion_matrix: scores and labels have different lengths (%d and %d)" (Array.length scores) (Array.length labels) () ;
  Array.fold2_exn scores labels ~init:zero ~f:(fun accu score label ->
      update accu ~threshold ~score ~label
    )

(** Number of actual positives, [tp + fn]. *)
let positive cm = cm.tp + cm.fn

(** Number of actual negatives, [tn + fp]. *)
let negative cm = cm.tn + cm.fp

(** Total number of examples. *)
let cardinal cm = cm.tp + cm.tn + cm.fp + cm.fn

(** [true] when nothing is predicted positive ([tp = 0] and [fp = 0]), in
    which case [precision] would be [0 / 0]. *)
let no_positives { tp ; fn = _ ; fp ; tn = _ } = tp = 0 && fp = 0

(* The ratios below use float division: a zero denominator implies a zero
   numerator here, so they evaluate to [nan] in that case. *)

(** [tp / (tp + fn)]. *)
let sensitivity cm =
  float cm.tp /. float (cm.tp + cm.fn)

(** Alias of [sensitivity]. *)
let recall = sensitivity

(** [fp / (fp + tn)]. *)
let false_positive_rate cm =
  float cm.fp /. float (cm.fp + cm.tn)

(** [(tp + tn) / cardinal]. *)
let accuracy cm =
  float (cm.tp + cm.tn) /. float (cardinal cm)

(** [tn / (fp + tn)]. *)
let specificity cm =
  float cm.tn /. float (cm.fp + cm.tn)

(** [tp / (tp + fp)]. *)
let positive_predictive_value cm =
  float cm.tp /. float (cm.tp + cm.fp)

(** Alias of [positive_predictive_value]. *)
let precision = positive_predictive_value

(** [tn / (tn + fn)]. *)
let negative_predictive_value cm =
  float cm.tn /. float (cm.tn + cm.fn)

(** [fp / (fp + tp)]. *)
let false_discovery_rate cm =
  float cm.fp /. float (cm.fp + cm.tp)

(** [2 tp / (2 tp + fp + fn)]. *)
let f1_score cm =
  2. *. float cm.tp /. float (2 * cm.tp + cm.fp + cm.fn)

(** [performance_curve ~scores ~labels] sweeps a threshold from [infinity]
    down through every distinct score. Each point [(t, cm)] is the confusion
    matrix obtained when examples with [score >= t] are predicted positive.
    Note the [>=]: [confusion_matrix] uses a strict [>].

    The first point always has threshold [infinity]. The following points
    come in decreasing threshold order (assuming no NaN scores). For empty
    inputs the result is [[| (infinity, zero) |]].
    @raise Invalid_argument if [scores] and [labels] have different
    lengths. *)
let performance_curve ~scores ~labels =
  let n = Array.length scores in
  if n <> Array.length labels
  then invalid_argf "Bin_pred.performance_curve: scores and labels have different lengths (%d and %d)" n (Array.length labels) () ;
  (* Examples sorted by decreasing score. *)
  let examples =
    let r = Array.map2_exn scores labels ~f:(fun x y -> x, y) in
    Array.sort ~compare:(Fn.flip Poly.compare) r ;
    r
  in
  let np = Array.count labels ~f:Fun.id in
  let nn = Array.count labels ~f:(fun x -> not x) in
  (* At threshold infinity every example is predicted negative. *)
  let initial = { tp = 0 ; tn = nn ; fp = 0 ; fn = np } in
  let rec loop acc current_threshold current_matrix i =
    if i = n then List.rev ((current_threshold, current_matrix) :: acc)
    else
      let score, label = examples.(i) in
      (* A point is emitted only when the score strictly drops, so examples
         with tied scores end up in the same point. *)
      let acc =
        if Float.(score < current_threshold) then
          (current_threshold, current_matrix) :: acc
        else acc
      in
      (* Lowering the threshold to [score] flips this example to a
         positive prediction. *)
      let new_matrix =
        if label then
          { current_matrix with tp = current_matrix.tp + 1 ;
                                fn = current_matrix.fn - 1 }
        else
          { current_matrix with fp = current_matrix.fp + 1 ;
                                tn = current_matrix.tn - 1 }
      in
      loop acc score new_matrix (i + 1)
  in
  loop [] Float.infinity initial 0
  |> Array.of_list

(* Area of the trapezoid under the segment (x1, y1) -- (x2, y2). *)
let trapez_area x1 x2 y1 y2 = 0.5 *. (y1 +. y2) *. (x2 -. x1)

(** Area under a piecewise-linear curve, computed with the trapezoidal rule.
    Points must come with non-decreasing x-coordinates, otherwise segments
    contribute negative area. This holds for [roc_curve], where the false
    positive rate grows as the threshold decreases. Returns [0.] for an
    empty array. *)
let auc points =
  if Array.length points = 0 then 0.
  else
    let f ((x1, y1), sum) ((x2, y2) as p) =
      (p, sum +. trapez_area x1 x2 y1 y2)
    in
    Array.fold points ~f ~init:(points.(0), 0.)
    |> snd

(** ROC curve as [(false_positive_rate, sensitivity)] points, in the order
    of [performance_curve], together with its area under the curve. *)
let roc_curve ~scores ~labels =
  let matrices = performance_curve ~scores ~labels in
  let curve =
    Array.map matrices ~f:(fun (_, m) ->
        false_positive_rate m, sensitivity m
      )
  in
  let auc = auc curve in
  curve, auc

(** [average_precision ~precision ~recall] is the sum over [i >= 1] of
    [precision.(i) *. (recall.(i) -. recall.(i - 1))]. The two arrays are
    parallel; recall is expected in non-decreasing order, as produced by
    [recall_precision_curve]. Returns [0.] when there are fewer than two
    points.
    @raise Invalid_argument if the arrays have different lengths. *)
let average_precision ~precision ~recall =
  let n = Array.length precision in
  if n <> Array.length recall
  then invalid_argf "Bin_pred.average_precision: precision and recall have different lengths (%d and %d)" n (Array.length recall) () ;
  if n < 2 then 0.
  else
    Array.init (n - 1) ~f:(fun i ->
        precision.(i + 1) *. (recall.(i + 1) -. recall.(i))
      )
    |> Array.sum (module Float) ~f:Fn.id

(** Recall/precision curve as [(recall, precision)] points, in the order of
    [performance_curve], together with its average precision. Points where
    nothing is predicted positive (precision undefined) are mapped to
    [(0., 1.)]. *)
let recall_precision_curve ~scores ~labels =
  let matrices = performance_curve ~scores ~labels in
  let curve =
    Array.map matrices ~f:(fun (_, m) ->
        if no_positives m then 0., 1.
        else recall m, precision m
      )
  in
  let recall, precision = Array.unzip curve in
  curve, average_precision ~recall ~precision

let%expect_test "performance curve 1" =
  let scores = [| 2.1 ; 1.2 ; 5.6 ; 0. |] in
  let labels = [| true ; false ; true ; false |] in
  let curve = performance_curve ~scores ~labels in
  print_endline (Sexp.to_string_hum (sexp_of_curve curve)) ;
  [%expect "
    ((INF ((tp 0) (tn 2) (fp 0) (fn 2))) (5.6 ((tp 1) (tn 2) (fp 0) (fn 1)))
     (2.1 ((tp 2) (tn 2) (fp 0) (fn 0))) (1.2 ((tp 2) (tn 1) (fp 1) (fn 0)))
     (0 ((tp 2) (tn 0) (fp 2) (fn 0))))"]

let%test "rp_curve perfect recognition" =
  let scores = [| 2.1 ; 1.2 ; 5.6 ; 0. |] in
  let labels = [| true ; false ; true ; false |] in
  let _, auc = recall_precision_curve ~scores ~labels in
  Float.(auc = 1.)

let%test "rp_curve against sklearn" =
  let scores = [|
    -0.20078869 ; 0.30423874 ; 0.20105976 ; 0.27523711 ; 0.42593404 ;
    -0.15043726 ; -0.08794601 ; -0.12733462 ; 0.22931596 ; -0.23913518 ;
    -0.06386267 ; -0.14958466 ; -0.04914839 ; 0.09898417 ; 0.0515638 ;
    -0.1142941 ; 0.16106135 ; 0.04871897 ; -0.08258102 ; -0.26105668 ;
    0.24693291 ; -0.18029058 ; -0.38384994 ; 0.26336904 ; 0.12585371 ;
    -0.03991278 ; 0.39424539 ; 0.42411536 ; -0.4790443 ; -0.30529061 ;
    -0.09281931 ; 0.01213433 ; -0.20204098 ; 0.40148935 ; -0.04536122 ;
    0.12179099 ; 0.06493837 ; -0.07007139 ; 0.0032915 ; -0.39635676 ;
    0.02619439 ; 0.20018683 ; 0.065023 ; 0.49589616 ; -0.28221895 ;
    0.31364573 ; 0.1906223 ; 0.11549516 ; 0.03145977 ; 0.22408591
  |]
  in
  let labels = [|
    true ; true ; true ; true ; true ;
    false ; true ; false ; true ; false ;
    false ; true ; false ; false ; true ;
    false ; false ; false ; true ; false ;
    true ; false ; false ; true ; true ;
    true ; true ; true ; false ; false ;
    false ; true ; false ; true ; false ;
    true ; false ; false ; false ; false ;
    true ; true ; true ; true ; false ;
    true ; true ; false ; true ; false
  |]
  in
  let curve, _ = recall_precision_curve ~scores ~labels in
  let recall, precision = Array.unzip curve in
  let ap = average_precision ~recall ~precision in
  Float.robustly_compare ap 0.8783170534965226 = 0