Browse Source

using weak pointers wherever possible

logicp 5 years ago
parent
commit
bd20677cd5
4 changed files with 73 additions and 56 deletions
  1. 1 1
      headers/send_interface.h
  2. 12 12
      headers/socket_listener.h
  3. 7 0
      main.cpp
  4. 53 43
      socket_listener.cpp

+ 1 - 1
headers/send_interface.h

@@ -7,7 +7,7 @@
 class SendInterface {
  public:
   virtual void sendMessage(int client_socket_fd,
-                           std::shared_ptr<char[]> message) = 0;
+                           std::weak_ptr<char[]> w_buffer_ptr) = 0;
 };
 
 #endif  // __SEND_INTERFACE_H__

+ 12 - 12
headers/socket_listener.h

@@ -32,7 +32,7 @@ class SocketListener : public SendInterface {
     std::function<void()> m_cb;
   };
   // constructor
-  SocketListener(std::string ipAddress, int port);
+  SocketListener(std::string ip_address, int port);
 
   // destructor
   ~SocketListener();
@@ -43,7 +43,7 @@ class SocketListener : public SendInterface {
    * @param[in] {std::string} The message to be sent
    */
   virtual void sendMessage(int client_socket_fd,
-                           std::shared_ptr<char[]> buffer) override;
+                           std::weak_ptr<char[]> w_buffer_ptr) override;
 
   MessageHandler createMessageHandler(std::function<void()> cb);
   /**
@@ -69,31 +69,31 @@ class SocketListener : public SendInterface {
 
   int waitForConnection(int listening);
 
-  void loop_check();
+  void loopCheck();
 
   void done();
 
-  void handle_loop();
+  void handleLoop();
 
   void detachThreads();
 
-  void push_to_queue(std::function<void()> fn);
+  void pushToQueue(std::function<void()> fn);
 
-  void handle_client_socket(int client_socket_fd,
-                            SocketListener::MessageHandler message_handler,
-                            std::shared_ptr<char[]> buf);
+  void handleClientSocket(int client_socket_fd,
+                          SocketListener::MessageHandler message_handler,
+                          const std::shared_ptr<char[]>& s_buffer_ptr);
 
-  // private members
+  /* private members */
+  // Server arguments
   std::string m_ip_address;
   int m_port;
+
   std::thread m_loop_thread;
-  std::queue<std::function<void()>> task_queue;
   std::mutex m_mutex_lock;
   std::condition_variable pool_condition;
   std::atomic<bool> accepting_tasks;
-  std::atomic<bool> shutdown_loop;
-  std::atomic<bool> m_loop_switch;
 
+  std::queue<std::function<void()>> task_queue;
   std::vector<std::thread> thread_pool;
 };
 

+ 7 - 0
main.cpp

@@ -3,6 +3,13 @@
 
 #include "headers/socket_listener.h"
 
+/** \mainpage
+ * SocketListener constructor takes 2 parameters (std::string ip, int port).
+ *
+ * Calling the "run()" method will cause it to for and handle multiple
+ * concurrent socket connections.
+ */
+
 int main(int argc, char** argv) {
   SocketListener server("0.0.0.0", 9009);
 

+ 53 - 43
socket_listener.cpp

@@ -20,7 +20,6 @@
 #include <string>
 #include <thread>
 #include <vector>
-
 int num_threads = std::thread::hardware_concurrency();
 
 /**
@@ -28,10 +27,7 @@ int num_threads = std::thread::hardware_concurrency();
  * Initialize with ip_address, port and message_handler
  */
 SocketListener::SocketListener(std::string ip_address, int port)
-    : m_ip_address(ip_address),
-      m_port(port),
-      accepting_tasks(true),
-      shutdown_loop(false) {}
+    : m_ip_address(ip_address), m_port(port), accepting_tasks(true) {}
 
 /**
  * Destructor
@@ -51,9 +47,15 @@ SocketListener::MessageHandler SocketListener::createMessageHandler(
  * pointer, to a client socket described by its file descriptor
  */
 void SocketListener::sendMessage(int client_socket_fd,
-                                 std::shared_ptr<char[]> s_ptr) {
-  send(client_socket_fd, s_ptr.get(), static_cast<size_t>(MAX_BUFFER_SIZE) + 1,
-       0);
+                                 std::weak_ptr<char[]> w_buffer_ptr) {
+  std::shared_ptr<char[]> s_buffer_ptr = w_buffer_ptr.lock();
+  if (s_buffer_ptr) {
+    send(client_socket_fd, s_buffer_ptr.get(),
+         static_cast<size_t>(MAX_BUFFER_SIZE) + 1, 0);
+  } else {
+    std::cout << "Could not send message to client " << client_socket_fd
+              << ". Buffer does not exist." << std::endl;
+  }
 }
 
 /**
@@ -65,14 +67,14 @@ bool SocketListener::init() {
   return true;
 }
 
-void SocketListener::push_to_queue(std::function<void()> fn) {
+void SocketListener::pushToQueue(std::function<void()> fn) {
   std::unique_lock<std::mutex> lock(m_mutex_lock);
   task_queue.push(fn);
   lock.unlock();
   pool_condition.notify_one();
 }
 
-void SocketListener::handle_loop() {
+void SocketListener::handleLoop() {
   std::string accepting_str = accepting_tasks == 0
                                   ? std::string("Not accepting tasks")
                                   : std::string("Accepting tasks");
@@ -95,16 +97,16 @@ void SocketListener::handle_loop() {
   }
 }
 
-void SocketListener::loop_check() {
-    for (int i = 0; i < task_queue.size() && i < (num_threads - 1); i++) {
-      thread_pool.push_back(std::thread([this]() { handle_loop(); }));
-    }
-    done();
-    std::this_thread::sleep_for(std::chrono::milliseconds(400));
-    detachThreads();
-    size_t task_num = task_queue.size();
-    std::cout << "Task num: " << task_num << std::endl;
-    accepting_tasks = true;
+void SocketListener::loopCheck() {
+  for (int i = 0; i < task_queue.size() && i < (num_threads - 1); i++) {
+    thread_pool.push_back(std::thread([this]() { handleLoop(); }));
+  }
+  done();
+  std::this_thread::sleep_for(std::chrono::milliseconds(400));
+  detachThreads();
+  size_t task_num = task_queue.size();
+  std::cout << "Task num: " << task_num << std::endl;
+  accepting_tasks = true;
 }
 
 void SocketListener::done() {
@@ -116,32 +118,35 @@ void SocketListener::done() {
   pool_condition.notify_all();
 }
 
-void SocketListener::handle_client_socket(
+void SocketListener::handleClientSocket(
     int client_socket_fd, SocketListener::MessageHandler message_handler,
-    std::shared_ptr<char[]> buf) {
-  while (true) {
-    memset(buf.get(), 0, MAX_BUFFER_SIZE);  // Zero the character buffer
+    const std::shared_ptr<char[]>& s_buffer_ptr) {
+  for (;;) {
+    memset(s_buffer_ptr.get(), 0,
+           MAX_BUFFER_SIZE);  // Zero the character buffer
     int bytes_received = 0;
     // Receive and write incoming data to buffer and return the number of
     // bytes received
     bytes_received =
-        recv(client_socket_fd, buf.get(),
+        recv(client_socket_fd, s_buffer_ptr.get(),
              MAX_BUFFER_SIZE - 2,  // Leave room for null-termination
              0);
-    buf.get()[MAX_BUFFER_SIZE - 1] = 0;  // Null-terminate the character buffer
+    s_buffer_ptr.get()[MAX_BUFFER_SIZE - 1] =
+        0;  // Null-terminate the character buffer
     if (bytes_received > 0) {
       std::cout << "Client " << client_socket_fd
                 << "\nBytes received: " << bytes_received
-                << "\nData: " << buf.get() << std::endl;
+                << "\nData: " << s_buffer_ptr.get() << std::endl;
       // Handle incoming message
       message_handler();
     } else {
-      std::cout << "Client " << client_socket_fd << " disconnected" << std::endl;
+      std::cout << "Client " << client_socket_fd << " disconnected"
+                << std::endl;
+      // Zero the buffer again before closing
+      memset(s_buffer_ptr.get(), 0, MAX_BUFFER_SIZE);
       break;
     }
   }
-  // Zero the buffer again before closing
-  memset(buf.get(), 0, MAX_BUFFER_SIZE);
   // TODO: Determine if we should free memory, or handle as class member
   close(client_socket_fd);  // Destroy client socket and deallocate its fd
 }
@@ -171,16 +176,21 @@ void SocketListener::run() {
       // Destroy listening socket and deallocate its file descriptor. Only use
       // the client socket now.
       close(listening_socket_fd);
-      std::shared_ptr<char[]> s_ptr(new char[MAX_BUFFER_SIZE]);
-      std::function<void()> message_send_fn = [this, client_socket_fd,
-                                               s_ptr]() {
-        this->sendMessage(client_socket_fd, s_ptr);
-      };
-      MessageHandler message_handler = createMessageHandler(message_send_fn);
-      std::cout << "Pushing client to queue" << std::endl;
-      push_to_queue(std::bind(&SocketListener::handle_client_socket, this,
-                              client_socket_fd, message_handler, s_ptr));
-      m_loop_thread = std::thread([this]() { loop_check(); });
+      {
+        std::shared_ptr<char[]> s_buffer_ptr(new char[MAX_BUFFER_SIZE]);
+        std::weak_ptr<char[]> w_buffer_ptr(s_buffer_ptr);
+        std::function<void()> message_send_fn = [this, client_socket_fd,
+                                                 w_buffer_ptr]() {
+          this->sendMessage(client_socket_fd, w_buffer_ptr);
+        };
+        MessageHandler message_handler = createMessageHandler(message_send_fn);
+        std::cout << "Pushing client to queue" << std::endl;
+        pushToQueue(
+            std::bind(&SocketListener::handleClientSocket, this,
+                      client_socket_fd, message_handler,
+                      std::forward<std::shared_ptr<char[]>>(s_buffer_ptr)));
+      }
+      m_loop_thread = std::thread([this]() { loopCheck(); });
       m_loop_thread.detach();
       accepting_tasks = false;
       std::cout << "At the end" << std::endl;
@@ -190,10 +200,10 @@ void SocketListener::run() {
 
 void SocketListener::detachThreads() {
   for (std::thread& t : thread_pool) {
-      if (t.joinable()) {
-        t.detach();
-      }
+    if (t.joinable()) {
+      t.detach();
     }
+  }
 }
 
 /**