En användning för multipla trådar som blir allt vanligare är att utnyttja flera processoror och kärnor för beräkningsintensiva applikationer. Ett område som jag brinner för är AI och maskinlärande, där algoritmerna oftast är CPU-intensiva. Genom att köra på flera processorer parallellt så blir beräkningen klar snabbare.
I denna artikel får Ni känna lite på maskinlärande, med 2 implementationer av en kNN klassificerare, en enkeltrådad och en multitrådad. Under resan får vi även bekanta oss med Strategy och Template Method, två kraftfulla designmönster samt definiera och använda generiska typer, ämnen som också behandlas på kursen. Artikeln visar en hel del javakod men bara de väsentligaste delarna. Om du vill se hela koden kan du ladda ner projektet här.
Maskinlärande
En vanlig uppgift inom maskinlärande är klassifikation. Utifrån ett antal instanser, eller exempel, skapar man en klassificerare som kan användas för att kategorisera nya instanser. Kreditbedömning, medicinsk diagnos och textanalys är bara några exempel. Neurala nätverk, Naive Bayes, SVM, kNN är några olika typer av klassificerare.Exempeldata
Vi tar kreditbedömning vid sms lån som exempel. I tabellen visas två tidigare ansökningar och dess utfall. Bedömningen kanske har skett manuellt tidigare men med ett stort antal exempel så kan vi bygga en beslutsmodell som avgör nya ansökningar.Kön | Årsinkomst | Lånebelopp | Operatör | Veckodag | Timme | Beviljat |
man | 190.000 | 800 | telenor | lördag | 3 | nej |
kvinna | 240.000 | 400 | telia | tisdag | 17 | ja |
Många ramverk för maskinlärande representerar alla attribut som en array av double. Vi håller oss till objektorienterad representation och definierar därför en klass för en låneansökan:
kNN - Nearest Neighbor
kNN är en algoritm som är lätt att förstå och lätt att implementera. Idén bygger på att man jämför instansen som skall klassificeras med ett stort antal tidigare instanser och tilldelar samma klass som den (k = 1) eller de (k > 1) som är mest lika. Till detta behöver vi en funktion som kan beräkna hur lika två instanser är. Denna funktion är specifik för varje typ av instans och definieras genom att implementera följande interface:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public interface DistanceFunction<T> { | |
double calculateDistance(T item, T item2); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public double calculateDistance(LoanApplication item, LoanApplication item2) { | |
double distance = 0; | |
//nominal attributes | |
if (item.getCarrier() != item2.getCarrier()) distance += 1; | |
if (item.getGender() != item2.getGender()) distance += 1; | |
int dayDiff = Math.abs(item.getDayOfWeek() - item2.getDayOfWeek()); | |
if (dayDiff > 3) { | |
dayDiff -= 1 + 2 * (dayDiff - 4); | |
} | |
distance += dayDiff * dayDiff; | |
int timeOfDayDiff = Math.abs(item.getTimeOfDay() - item2.getTimeOfDay()); | |
if(timeOfDayDiff > 11) { | |
timeOfDayDiff -= 1 + 2 * (timeOfDayDiff - 12); | |
} | |
distance += timeOfDayDiff * timeOfDayDiff; | |
//scaled numeric range attributes | |
//TODO: tune scale factors | |
int requestedAmountDiff = item.getRequestedAmount() - item2.getRequestedAmount(); | |
requestedAmountDiff /= 100; | |
distance += requestedAmountDiff * requestedAmountDiff; | |
int yearlyIncomeDiff = item.getYearlyIncome() - item2.getYearlyIncome(); | |
yearlyIncomeDiff /= 10000; | |
distance += yearlyIncomeDiff * yearlyIncomeDiff; | |
return Math.sqrt(distance); | |
} |
Här är själva klassificeringsmetoden för kNN, observera att loopen som beräknar alla avstånd går på en enda tråd.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public double classify(T item) { | |
int numItems = items.size(); | |
// calculated distances in same order as the items list | |
double[] distances = new double[numItems]; | |
//calculate the distance to each instance | |
for (int i = 0; i < numItems; i++) { | |
distances[i] = distanceFunction.calculateDistance(item, items.get(i)); | |
} | |
//Find the k nearest items | |
List<T> nearestItems = getNearestItems(distances); | |
double result = averageFunction.calculateAverage(nearestItems); | |
return result; | |
} |
Nästa steg är att göra beräkningen av alla avstånd på flera processorer. Ett sätt att implementera det är att bryta ut beräkningen till en egen metod (calculateDistances) som sedan överrids i en subklass. Om vi samtidigt lyfter upp den utbrutna metoden samt classify-implementation i en abstrakt basklass, så har vi ett exempel på designmönstret Template Method. Nedan följer de intressanta metoderna i ParallelNearestNeighborClassifier:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Thread pool from the Java 1.5 Executor Framework | |
*/ | |
private ExecutorService executorService; | |
/** | |
* Initialize the thread pool | |
*/ | |
private void init() { | |
int numThreads = Runtime.getRuntime().availableProcessors(); | |
executorService = Executors.newFixedThreadPool(numThreads); | |
} | |
protected double[] calculateDistances(final T item) { | |
int numThreads = Runtime.getRuntime().availableProcessors(); | |
int numItems = items.size(); | |
final double[] distances = new double[numItems]; | |
final int numItemsPerThread = numItems / numThreads; | |
//Create the tasks to be executed by the executor service | |
List<Callable<Void>> tasks = new ArrayList<Callable<Void>>(numThreads); | |
for (int i = 0; i < numThreads;i++) { | |
final int j = i; | |
tasks.add( new Callable<Void>() { | |
public Void call() { | |
calculateSegment(distances,item, j * numItemsPerThread, numItemsPerThread); | |
return null; | |
} | |
}); | |
} | |
//run all tasks, waiting for them to complete | |
try { | |
executorService.invokeAll(tasks); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
return distances; | |
} | |
/** | |
* Calculates distances for a specific slice of the items | |
* @param A reference to the entire array of distances | |
* @param item The item to compare with | |
* @param offset index of the first item to compare | |
* @param count number of items to compare | |
*/ | |
private void calculateSegment(double[] distances, T item, int offset, int count) { | |
for(int i = offset; i < offset + count; i++) { | |
//avoid array out of bounds when items not evenly | |
//divisible by number of threads | |
if (i == distances.length) break; | |
distances[i] = distanceFunction.calculateDistance(item, items.get(i)); | |
} | |
} |
Och slutligen ett test-program som jämför båda algoritmerna med olika mängder data.
Resultat
Av tabellen framgår att den multitrådade versionen går lite mer än dubbelt så fort utom för väldigt små datamängder. Min maskin har 8 kärnor och jag körde applikationen i en VMware instans med 4 kärnor tilldelade. I teorin skall det gå 4 gånger fortare men jag kom aldrig riktigt upp i fullt tryck (100% CPU) på alla 4 kärnor, vilket kan bero på den virtuella miljön. Som med alla typer av optimering är det bra att mäta och göra val utifrån faktisk prestanda och inte teoretiskt resonemang.Antal instanser | En tråd (ms) | Flera trådar (ms) |
5000 | 102 | 73 |
10000 | 91 | 85 |
50000 | 376 | 251 |
100000 | 781 | 385 |
1000000 | 7681 | 3353 |
5000000 | 39376 | 16893 |
10000000 | 83029 | 34699 |
PS. AI och maskinlärande är inget som tas upp på kursen, men om det finns intresse så kan vi ordna en workshop. Kontakta din lokala säljare!
1 kommentarer :
Uppdaterade posten, gistarna laddade inte!
SvaraSkicka en kommentar
Trevligt att du vill dela med dig av dina åsikter! Tänk på att hålla på "Netiketten" och använda vårdat språk.